keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计算时间差的方法
May 20 Python
Python代码解决RenderView窗口not found问题
Aug 28 Python
Python模拟脉冲星伪信号频率实例代码
Jan 03 Python
用pandas中的DataFrame时选取行或列的方法
Jul 11 Python
pycharm在调试python时执行其他语句的方法
Nov 29 Python
详解Python Matplot中文显示完美解决方案
Mar 07 Python
使用Python完成15位18位身份证的互转功能
Nov 06 Python
如何基于python实现画不同品种的樱花树
Jan 03 Python
利用Python制作动态排名图的实现代码
Apr 09 Python
Python importlib动态导入模块实现代码
Apr 16 Python
python 动态渲染 mysql 配置文件的示例
Nov 20 Python
新手必备Python开发环境搭建教程
May 28 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
超级好用的一个php上传图片类(随机名,缩略图,加水印)
2010/06/30 PHP
php XPath对XML文件查找及修改实现代码
2011/07/27 PHP
探讨:array2xml和xml2array以及xml与array的互相转化
2013/06/24 PHP
详解PHP的Yii框架中组件行为的属性注入和方法注入
2016/03/18 PHP
php如何获取Http请求
2020/04/30 PHP
[原创]网络复制内容时常用的正则+editplus
2006/11/30 Javascript
javascript算法学习(直接插入排序)
2011/04/12 Javascript
JS动态改变浏览器标题的方法
2016/04/06 Javascript
原生js制作日历控件实例分享
2016/04/06 Javascript
JS触摸屏网页版仿app弹窗型滚动列表选择器/日期选择器
2016/10/30 Javascript
基于JS实现移动端向左滑动出现删除按钮功能
2017/02/22 Javascript
Vue.js实战之组件之间的数据传递
2017/04/01 Javascript
vue 组件中slot插口的具体用法
2018/04/03 Javascript
jQuery实现的电子时钟效果完整示例
2018/04/28 jQuery
微信小程序项目实践之主页tab选项实现
2018/07/18 Javascript
vue项目每30秒刷新1次接口的实现方法
2018/12/04 Javascript
vuex 多模块时 模块内部的mutation和action的调用方式
2020/07/24 Javascript
[00:55]2015国际邀请赛中国区预选赛5月23日——28日约战上海
2015/05/25 DOTA
[01:03:31]DOTA2上海特级锦标赛B组资格赛#1 Alliance VS Fnatic第二局
2016/02/26 DOTA
python实现数组插入新元素的方法
2015/05/22 Python
python实现图像识别功能
2018/01/29 Python
PyCharm安装第三方库如Requests的图文教程
2018/05/18 Python
Python绘制的二项分布概率图示例
2018/08/22 Python
python utc datetime转换为时间戳的方法
2019/01/15 Python
keras 特征图可视化实例(中间层)
2020/01/24 Python
Python selenium文件上传下载功能代码实例
2020/04/13 Python
Python fileinput模块如何逐行读取多个文件
2020/10/05 Python
Html5剪切板功能的实现代码
2018/06/29 HTML / CSS
利用html5 canvas动态画饼状图的示例代码
2018/04/02 HTML / CSS
Bugatchi官方网站:男士服装在线
2019/04/10 全球购物
如何利用cmp命令比较文件
2016/04/11 面试题
小摄影师教学反思
2014/04/27 职场文书
消费者投诉书范文
2015/07/02 职场文书
2016年师德学习心得体会
2016/01/12 职场文书
Vue3中toRef与toRefs的区别
2022/03/24 Vue.js
Python中npy和mat文件的保存与读取
2022/04/24 Python