keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
安装dbus-python的简要教程
May 05 Python
python实现带声音的摩斯码翻译实现方法
May 20 Python
Python实现的选择排序算法示例
Nov 29 Python
python 异或加密字符串的实例
Oct 14 Python
Python字典的基本用法实例分析【创建、增加、获取、修改、删除】
Mar 05 Python
tensorflow如何批量读取图片
Aug 29 Python
Django之路由层的实现
Sep 09 Python
Django 请求Request的具体使用方法
Nov 11 Python
python 如何将office文件转换为PDF
Sep 22 Python
用python对excel查重
Dec 07 Python
基于Pytorch版yolov5的滑块验证码破解思路详解
Feb 25 Python
教你怎么用Python生成九宫格照片
May 20 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
了解咖啡雨林联盟认证 什么是雨林认证 雨林认证是什么意思
2021/03/05 新手入门
PHP中error_log()函数的使用方法
2015/01/20 PHP
Thinkphp自定义代码生成工具及用法说明(附下载地址)
2016/05/27 PHP
在jQuery ajax中按钮button和submit的区别分析
2012/10/07 Javascript
json属性名为什么要双引号(个人猜测)
2014/07/31 Javascript
nodejs URL模块操作URL相关方法介绍
2015/03/03 NodeJs
详解JavaScript中的4种类型识别方法
2015/09/14 Javascript
js倒计时简单实现方法
2015/12/17 Javascript
微信小程序 slider的简单实例
2017/04/19 Javascript
vue小白入门教程
2018/04/02 Javascript
NodeJs项目中关闭ESLint的方法
2018/08/09 NodeJs
浅谈webpack4 图片处理汇总
2018/09/12 Javascript
Vue 使用计时器实现跑马灯效果的实例代码
2019/07/11 Javascript
使用element-ui的el-menu导航选中后刷新页面保持当前选中状态
2019/07/19 Javascript
[36:17]DOTA2上海特级锦标赛 - VGL音乐会全集
2016/03/06 DOTA
python抓取网页内容并进行语音播报的方法
2018/12/24 Python
Python向excel中写入数据的方法
2019/05/05 Python
在python中将list分段并保存为array类型的方法
2019/07/15 Python
Python API自动化框架总结
2019/11/12 Python
使用PyTorch训练一个图像分类器实例
2020/01/08 Python
关于Python解包知识点总结
2020/05/05 Python
html特殊符号示例 html特殊字符编码对照表
2014/01/14 HTML / CSS
Holiday Inn中国官网:IHG旗下假日酒店预订
2018/04/08 全球购物
Jack Rogers官网:美国经典的女性鞋靴品牌
2019/09/04 全球购物
娱乐地球:Entertainment Earth
2020/01/08 全球购物
为什么group by 和order by会使查询变慢
2014/05/16 面试题
说一下Linux下有关用户和组管理的命令
2016/01/04 面试题
医院办公室主任职责
2013/12/29 职场文书
服装采购员岗位职责
2014/03/15 职场文书
服务承诺书格式
2014/05/21 职场文书
乡镇综治宣传月活动总结
2014/07/02 职场文书
环境保护与污染治理求职信
2014/07/16 职场文书
党员教师个人对照检查材料范文
2014/09/25 职场文书
庆祝三八妇女节标语
2014/10/09 职场文书
市场营销计划书
2015/01/17 职场文书
2016年公务员六五普法心得体会
2016/01/21 职场文书