keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置函数——__import__ 的使用方法
Nov 24 Python
浅谈pandas中DataFrame关于显示值省略的解决方法
Apr 08 Python
使用python生成杨辉三角形的示例代码
Aug 29 Python
Python实现多级目录压缩与解压文件的方法
Sep 01 Python
详解用Python实现自动化监控远程服务器
May 18 Python
django创建简单的页面响应实例教程
Sep 06 Python
TensorFlow2.X使用图片制作简单的数据集训练模型
Apr 08 Python
jupyter notebook中新建cell的方法与快捷键操作
Apr 22 Python
Python使用20行代码实现微信聊天机器人
Jun 05 Python
Python中的With语句的使用及原理
Jul 29 Python
Python在字符串中处理html和xml的方法
Jul 31 Python
Python文件的操作示例的详细讲解
Apr 08 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
全局记录程序片段的运行时间 正确找到程序逻辑耗时多的断点
2011/01/06 PHP
浅析PHP编程中10个最常见的错误
2014/08/08 PHP
php微信开发之批量生成带参数的二维码
2016/06/26 PHP
PHP pthreads v3下worker和pool的使用方法示例
2020/02/21 PHP
学习jquery必备 api中英文对照的chm手册 下载
2007/05/03 Javascript
jquery 双色表格实现代码
2009/12/08 Javascript
javascript设计模式 封装和信息隐藏(上)
2012/07/24 Javascript
jQuery学习笔记(3)--用jquery(插件)实现多选项卡功能
2013/04/08 Javascript
js判断浏览器类型为ie6时不执行
2014/06/15 Javascript
JavaScript获取页面中表单(form)数量的方法
2015/04/03 Javascript
js父页面中使用子页面的方法
2016/01/09 Javascript
详解jQuery lazyload 懒加载
2016/12/19 Javascript
微信小程序用户自定义模版用法实例分析
2017/11/28 Javascript
Vue+webpack项目配置便于维护的目录结构教程详解
2018/10/14 Javascript
Vue中el-form标签中的自定义el-select下拉框标签功能
2020/04/20 Javascript
JS中队列和双端队列实现及应用详解
2020/09/29 Javascript
Vue实现手机号、验证码登录(60s禁用倒计时)
2020/12/19 Vue.js
原生JavaScript实现进度条
2021/02/19 Javascript
[54:56]DOTA2上海特级锦标赛主赛事日 - 5 总决赛Liquid VS Secret第三局
2016/03/06 DOTA
[25:59]Newbee vs TNC 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
Python标准库defaultdict模块使用示例
2015/04/28 Python
python计算牛顿迭代多项式实例分析
2015/05/07 Python
[原创]教女朋友学Python(一)运行环境搭建
2017/11/29 Python
对python3 Serial 串口助手的接收读取数据方法详解
2019/06/12 Python
python实现图像全景拼接
2020/03/27 Python
python:解析requests返回的response(json格式)说明
2020/04/30 Python
Python ADF 单位根检验 如何查看结果的实现
2020/06/03 Python
matplotlib 三维图表绘制方法简介
2020/09/20 Python
C++:memset ,memcpy和strcpy的根本区别
2013/04/27 面试题
函授毕业生的自我鉴定
2013/11/26 职场文书
信息专业大学生自我评价分享
2014/01/17 职场文书
家长通知书家长评语
2014/04/17 职场文书
工伤劳动仲裁代理词
2015/05/25 职场文书
小学庆六一主持词
2015/06/30 职场文书
解析laravel使用workerman用户交互、服务器交互
2021/04/28 PHP
用python实现监控视频人数统计
2021/05/21 Python