keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python服务器端收发请求的实现代码
Sep 29 Python
关于你不想知道的所有Python3 unicode特性
Nov 28 Python
仅用500行Python代码实现一个英文解析器的教程
Apr 02 Python
django2+uwsgi+nginx上线部署到服务器Ubuntu16.04
Jun 26 Python
Python退火算法在高次方程的应用
Jul 26 Python
Python中几种属性访问的区别与用法详解
Oct 10 Python
使用python 对验证码图片进行降噪处理
Dec 18 Python
Python matplotlib画图时图例说明(legend)放到图像外侧详解
May 16 Python
python代码如何注释
Jun 01 Python
对Pytorch 中的contiguous理解说明
Mar 03 Python
pandas中DataFrame数据合并连接(merge、join、concat)
May 30 Python
Python四款GUI图形界面库介绍
Jun 05 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
关于我转生变成史莱姆这档事:第二季PV上线,萌王2021年回归
2020/05/06 日漫
dojo 之基础篇
2007/03/24 Javascript
javascript 继承实现方法
2009/08/26 Javascript
dropdownlist之间的互相联动实现(显示与隐藏)
2009/11/24 Javascript
jquery垂直公告滚动实现代码
2013/12/08 Javascript
js与jquery实时监听输入框值的oninput与onpropertychange方法
2015/02/05 Javascript
javascript删除数组重复元素的方法汇总
2015/06/24 Javascript
AngularJS实现元素显示和隐藏的几个案例
2015/12/09 Javascript
基于JQuery实现图片轮播效果(焦点图)
2016/02/02 Javascript
jQuery简单设置文本框回车事件的方法
2016/08/01 Javascript
微信小程序实现文字跑马灯
2020/05/26 Javascript
JavaScript使用闭包模仿块级作用域操作示例
2019/01/21 Javascript
详解vue中移动端自适应方案
2019/05/05 Javascript
Vue 实现把表单form数据 转化成json格式的数据
2019/10/29 Javascript
微信小程序地图绘制线段并且测量(实例代码)
2020/01/02 Javascript
JavaScript实现动态生成表格
2020/08/02 Javascript
[01:37]TI4西雅图DOTA2前线报道 VG拿下首胜教练357给出获胜秘诀
2014/07/10 DOTA
在Python中操作文件之seek()方法的使用教程
2015/05/24 Python
pycham查看程序执行的时间方法
2018/11/29 Python
python3 pygame实现接小球游戏
2019/05/14 Python
Python Django 简单分页的实现代码解析
2019/08/21 Python
澳大利亚领先的在线葡萄酒零售商:Get Wines Direct
2018/03/27 全球购物
街头时尚在线:JESSICABUURMAN
2019/06/16 全球购物
锐步英国官网:Reebok英国
2019/11/29 全球购物
介绍一下Java中的Class类
2015/04/10 面试题
如何定义一个可复用的服务
2014/09/30 面试题
领导失职检讨书
2014/02/24 职场文书
副总经理任命书
2014/06/05 职场文书
小学安全汇报材料
2014/08/14 职场文书
校园会短篇的广播稿
2014/10/21 职场文书
2014酒店客房部工作总结
2014/12/16 职场文书
党校党性分析材料
2014/12/19 职场文书
2015年小学重阳节活动总结
2015/07/29 职场文书
HR在给员工开具离职证明时,需要注意哪些问题?
2019/07/03 职场文书
Python Matplotlib绘制动画的代码详解
2022/05/30 Python
Spring Boot项目如何优雅实现Excel导入与导出功能
2022/06/10 Java/Android