keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
树莓派中python获取GY-85九轴模块信息示例
Dec 05 Python
在Django的通用视图中处理Context的方法
Jul 21 Python
python 截取 取出一部分的字符串方法
Mar 01 Python
python数据结构之链表详解
Sep 12 Python
详解python中asyncio模块
Mar 03 Python
Python循环中else,break和continue的用法实例详解
Jul 11 Python
使用pyshp包进行shapefile文件修改的例子
Dec 06 Python
Pandas时间序列基础详解(转换,索引,切片)
Feb 26 Python
python实现梯度下降算法的实例详解
Aug 17 Python
5款实用的python 工具推荐
Oct 13 Python
python使用ctypes库调用DLL动态链接库
Oct 22 Python
Python+kivy BoxLayout布局示例代码详解
Dec 28 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
如何去掉文章里的 html 语法
2006/10/09 PHP
PHP 反向排序和随机排序代码
2010/06/30 PHP
PHP实现把文本中的URL转换为链接的auolink()函数分享
2014/07/29 PHP
PHP中把有符号整型转换为无符号整型方法
2015/05/27 PHP
php版微信js-sdk支付接口类用法示例
2016/10/12 PHP
php array_reverse 以相反的顺序返回数组实例代码
2017/04/11 PHP
Yii框架数据库查询、增加、删除操作示例
2019/10/14 PHP
tp5框架使用cookie加密算法实现登录功能示例
2020/02/10 PHP
Javascript 判断 object 的特定类转载
2007/02/01 Javascript
JQuery与iframe交互实现代码
2009/12/24 Javascript
JavaScript使用focus()设置焦点失败的解决方法
2014/09/03 Javascript
JQuery异步获取返回值中文乱码的解决方法
2015/01/29 Javascript
javascript实现textarea中tab键的缩排处理方法
2015/06/26 Javascript
JavaScript实现基于十进制的四舍五入实例
2015/07/17 Javascript
node.js利用redis数据库缓存数据的方法
2017/03/01 Javascript
jquery动态添加带有样式的HTML标签元素方法
2018/02/24 jQuery
vue.js 实现评价五角星组件的实例代码
2018/08/13 Javascript
vue+webpack中配置ESLint
2018/11/07 Javascript
微信小程序冒泡事件及其阻止方法实例分析
2018/12/06 Javascript
详解使用uni-app开发微信小程序之登录模块
2019/05/09 Javascript
Vue项目页面跳转时浏览器窗口上方显示进度条功能
2020/03/26 Javascript
简单实现python数独游戏
2018/03/30 Python
详谈python在windows中的文件路径问题
2018/04/28 Python
对web.py设置favicon.ico的方法详解
2018/12/04 Python
Python用Try语句捕获异常的实例方法
2019/06/26 Python
浅析PyTorch中nn.Module的使用
2019/08/18 Python
CSS3打造百度贴吧的3D翻牌效果示例
2017/01/04 HTML / CSS
html5的canvas实现3d雪花飘舞效果
2013/12/27 HTML / CSS
C#如何进行LDAP用户校验
2012/11/21 面试题
一道Delphi面试题
2016/10/28 面试题
教育专业毕业生推荐信
2014/07/10 职场文书
施工员岗位职责
2015/02/10 职场文书
2016继续教育研修日志
2015/11/13 职场文书
2016元旦晚会主持词开场白和结束语
2015/12/04 职场文书
深度学习小工程练习之垃圾分类详解
2021/04/14 Python
不想升级Win11?教你彻底锁定老版Windows系统的方法(附下载地址)
2022/09/23 数码科技