keras在构建LSTM模型时对变长序列的处理操作


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

print(np.shape(X))#(1920, 45, 20)
X=sequence.pad_sequences(X, maxlen=100, padding='post')
print(np.shape(X))#(1920, 100, 20)

model = Sequential()
model.add(Masking(mask_value=0,input_shape=(100,20)))
model.add(LSTM(128,dropout_W=0.5,dropout_U=0.5))
model.add(Dense(13,activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='adam',
       metrics=['accuracy'])

# 用于保存验证集误差最小的参数,当验证集误差减少时,保存下来
checkpointer = ModelCheckpoint(filepath="keras_rnn.hdf5", verbose=1, save_best_only=True, )
history = LossHistory()
result = model.fit(X, Y, batch_size=10,
          nb_epoch=500, verbose=1, validation_data=(testX, testY),
          callbacks=[checkpointer, history])

model.save('keras_rnn_epochend.hdf5')

补充知识:RNN(LSTM)数据形式及Padding操作处理变长时序序列dynamic_rnn

Summary

RNN

样本一样,计算的状态值和输出结构一致,也即是说只要当前时刻的输入值也前一状态值一样,那么其当前状态值和当前输出结果一致,因为在当前这一轮训练中权重参数和偏置均未更新

RNN的最终状态值与最后一个时刻的输出值一致

输入数据要求格式为,shape=(batch_size, step_time_size, input_size),那么,state的shape=(batch_size, state_size);output的shape=(batch_size, step_time_size, state_size),并且最后一个有效输出(有效序列长度,不包括padding的部分)与状态值会一样

LSTM

LSTM与RNN基本一致,不同在于其状态有两个c_state和h_state,它们的shape一样,输出值output的最后一个有效输出与h_state一致

用变长RNN训练,要求其输入格式仍然要求为shape=(batch_size, step_time_size, input_size),但可指定每一个批次中各个样本的有效序列长度,这样在有效长度内其状态值和输出值原理不变,但超过有效长度的部分的状态值将不会发生改变,而输出值都将是shape=(state_size,)的零向量(注:RNN也是这个原理)

需要说明的是,不是因为无效序列长度部分全padding为0而引起输出全为0,状态不变,因为输出值和状态值得计算不仅依赖当前时刻的输入值,也依赖于上一时刻的状态值。其内部原理是利用一个mask matrix矩阵标记有效部分和无效部分,这样在无效部分就不用计算了,也就是说,这一部分不会造成反向传播时对参数的更新。当然,如果padding不是零,那么padding的这部分输出和状态同样与padding为零的结果是一样的

'''
#样本数据为(batch_size,time_step_size, input_size[embedding_size])的形式,其中samples=4,timesteps=3,features=3,其中第二个、第四个样本是只有一个时间步长和二个时间步长的,这里自动补零
'''
import pandas as pd
import numpy as np
import tensorflow as tf

train_X = np.array([[[0, 1, 2], [9, 8, 7], [3,6,8]], 
          [[3, 4, 5], [0, 10, 110], [0,0,0]], 
          [[6, 7, 8], [6, 5, 4], [1,7,4]], 
          [[9, 0, 1], [3, 7, 4], [0,0,0]],
          [[9, 0, 1], [3, 3, 4], [0,0,0]]
          ])
          
sequence_length = [3, 1, 3, 2, 2]

train_X.shape, train_X[:,2:3,:].reshape(5, 3)
tf.reset_default_graph()

x = tf.placeholder(tf.float32, shape=(None, 3, 3)) # 输入数据只需能够迭代并符合要求shape即可,list也行,shape不指定表示没有shape约束,任意shape均可
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=6) # state_size[hidden_size]
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=6) # state_size[hidden_size]
outputs1, state1 = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32, sequence_length=sequence_length)
outputs2, state2 = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32, sequence_length=sequence_length)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) # 初始化rnn_cell中参数变量
  outputs1, state1 = sess.run((outputs1, state1), feed_dict={x: train_X})
  outputs2, state2 = sess.run([outputs2, state2], feed_dict={x: train_X})
  print(outputs1.shape, state1.shape) # (4, 3, 5)->(batch_size, time_step_size, state_size), (4, 5)->(batch_size, state_size)
  print(outputs2.shape) # state2为LSTMStateTuple(c_state, h_state)
  print("---------output1<rnn>state1-----------")
  print(outputs1) # 可以看出output1的最后一个时刻的输出即为state1, 即output1[:,-1,:]与state1相等
  print(state1)
  print(np.all(outputs1[:,-1,:] == state1))
  print("---------output2<lstm>state2-----------")
  print(outputs2) # 可以看出output2的最后一个时刻的输出即为LSTMStateTuple中的h
  print(state2)
  print(np.all(outputs2[:,-1,:] == state2[1]))

再来怼怼dynamic_rnn中数据序列长度tricks

keras在构建LSTM模型时对变长序列的处理操作

思路样例代码

from collections import Counter
import numpy as np

origin_data = np.array([[1, 2, 3],
            [3, 0, 2],
            [1, 1, 4],
            [2, 1, 2],
            [0, 1, 1],
            [2, 0, 3]
            ])
# 按照指定列索引进行分组(看作RNN中一个样本序列),如下为按照第二列分组的结果
# [[[1, 2, 3], [0, 0, 0], [0, 0, 0]],
# [[3, 0, 2], [2, 0, 3], [0, 0, 0]],
# [[1, 1, 4], [2, 1, 2], [0, 1, 1]]]

# 第一步,将原始数据按照某列序列化使之成为一个序列数据
def groupby(a, col_index): # 未加入索引越界判断
  max_len = max(Counter(a[:, col_index]).values())
  for i in set(a[:, col_index]):
    d[i] = []
  for sample in a:
    d[sample[col_index]].append(list(sample))
#   for key in d:
#     d[key].extend([[0]*a.shape[1] for _ in range(max_len-len(d[key]))])
  return list(d.values()), [len(_) for _ in d.values()]

samples, sizes = groupby(origin_data, 2)
# 第二步,根据当前这一批次的中最大序列长度max(sizes)作为padding标准(不同批次的样本序列长度可以不一样,但同一批次要求一样(包括padding的部分)),当然也可以一次性将所有样本(不按照批量)按照最大序列长度padding也行,可能空间浪费
paddig_samples = np.zeros([len(samples), max(sizes), 3])
for seq_index, seq in enumerate(samples):
  paddig_samples[seq_index, :len(seq), :] = seq
paddig_samples

以上这篇keras在构建LSTM模型时对变长序列的处理操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中文乱码的解决方法
Nov 04 Python
关于pip的安装,更新,卸载模块以及使用方法(详解)
May 19 Python
Python基于递归算法实现的走迷宫问题
Aug 04 Python
Python的SimpleHTTPServer模块用处及使用方法简介
Jan 22 Python
在VS Code上搭建Python开发环境的方法
Apr 06 Python
Python中的二维数组实例(list与numpy.array)
Apr 13 Python
pycharm使用matplotlib.pyplot不显示图形的解决方法
Oct 28 Python
Python基于opencv实现的简单画板功能示例
Mar 04 Python
PyQt5中向单元格添加控件的方法示例
Mar 24 Python
基于python检查矩阵计算结果
May 21 Python
浅谈django不使用restframework自定义接口与使用的区别
Jul 15 Python
Python+pyaudio实现音频控制示例详解
Jul 23 Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
keras K.function获取某层的输出操作
Jun 29 #Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
You might like
PHP与MySQL开发中页面出现乱码的一种解决方法
2007/07/29 PHP
php 什么是PEAR?(第三篇)
2009/03/19 PHP
用PHP获取Google AJAX Search API 数据的代码
2010/03/12 PHP
PHP 函数call_user_func和call_user_func_array用法详解
2014/03/02 PHP
yii2中添加验证码的实现方法
2016/01/09 PHP
PHP中empty,isset,is_null用法和区别
2017/02/19 PHP
基于jquery的一个OutlookBar类,动态创建导航条
2010/11/19 Javascript
jquery+css3打造一款ajax分页插件(自写)
2014/06/18 Javascript
基于NodeJS的前后端分离的思考与实践(三)轻量级的接口配置建模框架
2014/09/26 NodeJs
js实现仿百度瀑布流的方法
2015/02/05 Javascript
jQuery插件支持同一页面被多次调用
2016/02/14 Javascript
基于JS代码实现图片在页面中旋转效果
2016/06/16 Javascript
多功能jQuery树插件zTree实现权限列表简单实例
2016/07/12 Javascript
JavaScript中this的用法及this在不同应用场景的作用解析
2017/04/13 Javascript
详解Vue爬坑之vuex初识
2017/06/14 Javascript
EasyUI框架 使用Ajax提交注册信息的实现代码
2017/09/27 Javascript
详解vue项目首页加载速度优化
2017/10/18 Javascript
javaScript实现鼠标在文字上悬浮时弹出悬浮层效果
2020/04/12 Javascript
浅析Vue下的components模板使用及应用
2019/11/27 Javascript
微信小程序点击按钮动态切换input的disabled禁用/启用状态功能
2020/03/07 Javascript
Node.js API详解之 util模块用法实例分析
2020/05/09 Javascript
Vue实现穿梭框效果
2020/09/30 Javascript
[01:02:46]VGJ.S vs NB 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
在Python下尝试多线程编程
2015/04/28 Python
python 基础教程之Map使用方法
2017/01/17 Python
python 解压、复制、删除 文件的实例代码
2020/02/26 Python
Python使用matplotlib绘制圆形代码实例
2020/05/27 Python
轻松掌握CSS3中的字体大小单位rem的使用方法
2016/05/24 HTML / CSS
菲律宾优惠券网站:MetroDeal
2019/04/12 全球购物
年会搞笑主持词
2014/03/27 职场文书
校庆筹备方案
2014/03/30 职场文书
感恩小明星事迹材料
2014/05/23 职场文书
中学生运动会通讯稿大全
2014/09/18 职场文书
中学总务处工作总结
2015/08/12 职场文书
Python趣味挑战之教你用pygame画进度条
2021/05/31 Python
使用canvas仿Echarts实现金字塔图的实例代码
2021/11/11 HTML / CSS