keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过colorama模块在控制台输出彩色文字的方法
Mar 19 Python
在Django框架中运行Python应用全攻略
Jul 17 Python
详解Python中的文件操作
Aug 28 Python
Python实现监控键盘鼠标操作示例【基于pyHook与pythoncom模块】
Sep 04 Python
python 自定义异常和异常捕捉的方法
Oct 18 Python
Python实现监控Nginx配置文件的不同并发送邮件报警功能示例
Feb 26 Python
python3实现表白神器
Apr 09 Python
Django学习笔记之为Model添加Action
Apr 30 Python
通过python扫描二维码/条形码并打印数据
Nov 14 Python
TensorFlow tf.nn.conv2d实现卷积的方式
Jan 03 Python
Python利用matplotlib绘制折线图的新手教程
Nov 05 Python
python 爬取腾讯视频评论的实现步骤
Feb 18 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
Codeigniter框架的更新事务(transaction)BUG及解决方法
2014/07/25 PHP
基于PHP实现的事件机制实例分析
2015/06/18 PHP
浅谈php中include文件变量作用域
2015/06/18 PHP
微信利用PHP创建自定义菜单的方法
2016/08/01 PHP
PHP笛卡尔积实现算法示例
2018/07/30 PHP
Javascript Object.extend
2010/05/18 Javascript
extjs关于treePanel+chekBox全部选中以及清空选中问题探讨
2013/04/02 Javascript
jquery通过load获取文件的内容并跳到锚点的方法
2015/01/29 Javascript
BootStrap轻松实现微信页面开发代码分享
2016/10/21 Javascript
禁用backspace网页回退功能的实现代码
2016/11/15 Javascript
使用jsonp实现跨域获取数据实例讲解
2016/12/25 Javascript
Bootstrap Multiselect 常用组件实现代码
2017/07/09 Javascript
Angular之toDoList的实现代码示例
2017/12/02 Javascript
vue.js或js实现中文A-Z排序的方法
2018/03/08 Javascript
在vue中使用css modules替代scroped的方法
2018/03/10 Javascript
在小程序/mpvue中使用flyio发起网络请求的方法
2018/09/13 Javascript
小程序扫描普通链接二维码跳转小程序指定界面方法
2019/05/07 Javascript
浅谈Vue.js之初始化el以及数据的绑定说明
2019/11/14 Javascript
Tensorflow 合并通道及加载子模型的方法
2018/07/26 Python
python绘制已知点的坐标的直线实例
2019/07/04 Python
python写程序统计词频的方法
2019/07/29 Python
Python中join()函数多种操作代码实例
2020/01/13 Python
在TensorFlow中屏蔽warning的方式
2020/02/04 Python
Python求凸包及多边形面积教程
2020/04/12 Python
python使用建议与技巧分享(二)
2020/08/17 Python
使用python tkinter开发一个爬取B站直播弹幕工具的实现代码
2021/02/07 Python
西雅图的买手店:Totokaelo
2019/10/19 全球购物
广告学专业推荐信范文
2013/11/23 职场文书
军训自我鉴定
2014/01/22 职场文书
入党积极分子自我鉴定范文
2014/03/25 职场文书
询价采购方案
2014/06/09 职场文书
中国梦演讲稿3分钟
2014/08/19 职场文书
食品安全承诺书范文
2014/08/29 职场文书
交通事故和解协议书
2015/01/27 职场文书
(开源)微信小程序+mqtt,esp8266温湿度读取
2021/04/02 Javascript
mysql数据库入门第一步之创建表
2021/05/14 MySQL