keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详谈在flask中使用jsonify和json.dumps的区别
Mar 26 Python
Python实现的端口扫描功能示例
Apr 08 Python
opencv python统计及绘制直方图的方法
Jan 21 Python
Python使用sqlalchemy模块连接数据库操作示例
Mar 13 Python
python 的 scapy库,实现网卡收发包的例子
Jul 23 Python
Django REST framework 如何实现内置访问频率控制
Jul 23 Python
详解Python图像处理库Pillow常用使用方法
Sep 02 Python
PHP统计代码行数的小代码
Sep 19 Python
详解centos7+django+python3+mysql+阿里云部署项目全流程
Nov 15 Python
基于Python下载网络图片方法汇总代码实例
Jun 24 Python
Python获取百度热搜的完整代码
Apr 07 Python
python垃圾回收机制原理分析
Apr 13 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
PHP 页面编码声明方法详解(header或meta)
2010/03/12 PHP
drupal 代码实现URL重写
2011/05/04 PHP
php摘要生成函数(无乱码)
2012/02/04 PHP
CentOS下与Apache连接的PHP多版本共存方案实现详解
2015/12/19 PHP
php7函数,声明,返回值等新特性介绍
2018/05/25 PHP
PHP获取日期对应星期、一周日期、星期开始与结束日期的方法
2018/06/22 PHP
js常用函数 不错
2006/09/08 Javascript
关闭浏览器窗口弹出提示框并且可以控制其失效
2014/04/15 Javascript
chrome下jq width()方法取值为0的解决方法
2014/05/26 Javascript
js实现从数组里随机获取元素
2015/01/12 Javascript
Nginx上传文件全部缓存解决方案
2015/08/17 Javascript
WordPress中鼠标悬停显示和隐藏评论及引用按钮的实现
2016/01/12 Javascript
第十章之巨幕页头缩略图与警告框组件
2016/04/25 Javascript
JS字符串长度判断,超出进行自动截取的实例(支持中文)
2017/03/06 Javascript
vue 2.0组件与v-model详解
2017/03/27 Javascript
vue.js2.0点击获取自己的属性和jquery方法
2018/02/23 jQuery
JavaScript笛卡尔积超简单实现算法示例
2018/07/30 Javascript
使用mixins实现elementUI表单全局验证的解决方法
2019/04/02 Javascript
js 根据对象数组中的属性进行排序实现代码
2019/09/12 Javascript
刷新页面后让控制台的js代码继续执行
2019/09/20 Javascript
Vue.js自定义指令学习使用详解
2019/10/19 Javascript
jquery制作的移动端购物车效果完整示例
2020/02/24 jQuery
JS+CSS实现动态时钟
2021/02/19 Javascript
[02:12]2019完美世界全国高校联赛(春季赛)报名开启
2019/03/01 DOTA
python实现探测socket和web服务示例
2014/03/28 Python
python抓取网站的图片并下载到本地的方法
2018/05/22 Python
python中的colorlog库使用详解
2019/07/05 Python
Python3 集合set入门基础
2020/02/10 Python
python实现飞机大战项目
2020/03/11 Python
python 偷懒技巧——使用 keyboard 录制键盘事件
2020/09/21 Python
美国韩国化妆品和护肤品购物网站:Beautytap
2018/07/29 全球购物
大学生村官个人对照检查材料(群众路线)
2014/09/26 职场文书
领导班子四风问题对照检查材料
2014/09/27 职场文书
派出所副所长四风问题个人整改措施思想汇报
2014/10/13 职场文书
创业计划书之个人工作室
2019/08/22 职场文书
Python爬虫之自动爬取某车之家各车销售数据
2021/06/02 Python