keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程编程(三):threading.Thread类的重要函数和方法
Apr 05 Python
python 判断是否为正小数和正整数的实例
Jul 23 Python
浅谈Python2获取中文文件名的编码问题
Jan 09 Python
基于DATAFRAME中元素的读取与修改方法
Jun 08 Python
在python中使用with打开多个文件的方法
Jan 07 Python
python的几种矩阵相乘的公式详解
Jul 10 Python
Python字典常见操作实例小结【定义、添加、删除、遍历】
Oct 25 Python
django实现web接口 python3模拟Post请求方式
Nov 19 Python
Python使用matplotlib绘制Logistic曲线操作示例
Nov 28 Python
对pytorch的函数中的group参数的作用介绍
Feb 18 Python
pycharm设置python文件模板信息过程图解
Mar 10 Python
如何利用Python 进行边缘检测
Oct 14 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
php输出含有“#”字符串的方法
2017/01/18 PHP
php获取文件名称和扩展名的方法
2017/02/07 PHP
javascript下function声明一些小结
2007/12/28 Javascript
iphone safari不支持position fixed的解决方法
2012/05/04 Javascript
jquery 获取自定义属性(attr和prop)的实现代码
2012/06/27 Javascript
js中用window.open()打开多个窗口的name问题
2014/03/13 Javascript
JavaScript字符串对象toLowerCase方法入门实例(用于把字母转换为小写)
2014/10/17 Javascript
自己封装的常用javascript函数分享
2015/01/07 Javascript
使用js画图之饼图
2015/01/12 Javascript
Bootstrap基本组件学习笔记之导航(10)
2016/12/07 Javascript
在一个页面重复使用一个js函数的方法详解
2016/12/26 Javascript
vue2.0 keep-alive最佳实践
2017/07/06 Javascript
vue实现模态框的通用写法推荐
2018/02/26 Javascript
vue 实现边输入边搜索功能的实例讲解
2018/09/16 Javascript
vsCode安装使用教程和插件安装方法
2020/08/24 Javascript
Vue项目中使用better-scroll实现菜单映射功能方法
2019/09/11 Javascript
layui-tree实现Ajax异步请求后动态添加节点的方法
2019/09/23 Javascript
Python中的深拷贝和浅拷贝详解
2015/06/03 Python
Python去除、替换字符串空格的处理方法
2018/04/01 Python
python 一篇文章搞懂装饰器所有用法(建议收藏)
2019/08/23 Python
基于Django框架的权限组件rbac实例讲解
2019/08/31 Python
Python FtpLib模块应用操作详解
2019/12/12 Python
基于python3抓取pinpoint应用信息入库
2020/01/08 Python
详解Django关于StreamingHttpResponse与FileResponse文件下载的最优方法
2021/01/07 Python
拉丁舞学习者的自我评价
2013/10/27 职场文书
财务支持类个人的自我评价
2014/02/14 职场文书
煤矿班组长竞聘书
2014/03/31 职场文书
个人作风建设剖析材料
2014/10/11 职场文书
个人政风行风自查自纠报告
2014/10/21 职场文书
硕士毕业论文导师评语
2014/12/31 职场文书
2015年幼儿园安全工作总结
2015/05/12 职场文书
毕业欢送会致辞
2015/07/29 职场文书
教师信息技术学习心得体会
2016/01/21 职场文书
2016年六一文艺汇演开幕词
2016/03/04 职场文书
解决Python中的modf()函数取小数部分不准确问题
2021/05/28 Python
Python Pandas数据分析之iloc和loc的用法详解
2021/11/11 Python