Pytorch 如何实现LSTM时间序列预测


Posted in Python onMay 17, 2021

开发环境说明:

Python 35

Pytorch 0.2

CPU/GPU均可

1、LSTM简介

人类在进行学习时,往往不总是零开始,学习物理你会有数学基础、学习英语你会有中文基础等等。

于是对于机器而言,神经网络的学习亦可不再从零开始,于是出现了Transfer Learning,就是把一个领域已训练好的网络用于初始化另一个领域的任务,例如会下棋的神经网络可以用于打德州扑克。

我们这讲的是另一种不从零开始学习的神经网络——循环神经网络(Recurrent Neural Network, RNN),它的每一次迭代都是基于上一次的学习结果,不断循环以得到对于整体序列的学习,区别于传统的MLP神经网络,这种神经网络模型存在环型结构,

具体下所示:

Pytorch 如何实现LSTM时间序列预测

上图是RNN的基本单元,通过不断循环迭代展开模型如下所示,图中ht是神经网络的在t时刻的输出,xt是t时刻的输入数据。

这种循环结构对时间序列数据能够很好地建模,例如语音识别、语言建模、机器翻译等领域。

Pytorch 如何实现LSTM时间序列预测

但是普通的RNN对于长期依赖问题效果比较差,当序列本身比较长时,由于神经网络模型的训练是采用backward进行,在梯度链式法则中容易出现梯度消失和梯度爆炸的问题,需要进一步改进RNN的模型结构。

针对Simple RNN存在的问题,LSTM网络模型被提出,LSTM的核心是修改了增添了Cell State,即加入了LSTM CELL,通过输入门、输出门、遗忘门把上一时刻的hidden state和cell state传给下一个状态。

如下所示:

Pytorch 如何实现LSTM时间序列预测

遗忘门:ft = sigma(Wf*[ht-1, xt] + bf)

输入门:it = sigma(Wi*[ht-1, xt] + bi)

cell state initial: C't = tanh(Wc*[ht-1, xt] +bc)

cell state: Ct = ft*Ct-1+ itC't

输出门:ot = sigma(Wo*[ht-1, xt] + bo)

模型输出:ht = ot*tanh(Ct)

LSTM有很多种变型结构,实际工程化过程中用的比较多的是peephole,就是计算每个门的时候增添了cell state的信息,有兴趣的童鞋可以专研专研。

上一部分简单地介绍了LSTM的模型结构,下边将具体介绍使用LSTM模型进行时间序列预测的具体过程。

2、数据准备

对于时间序列,本文选取正弦波序列,事先产生一定数量的序列数据,然后截取前部分作为训练数据训练LSTM模型,后部分作为真实值与模型预测结果进行比较。正弦波的产生过程如下:

SeriesGen(N)方法用于产生长度为N的正弦波数值序列;

trainDataGen(seq,k)用于产生训练或测试数据,返回数据结构为输入输出数据。seq为序列数据,k为LSTM模型循环的长度,使用1~k的数据预测2~k+1的数据。

Pytorch 如何实现LSTM时间序列预测

3、模型构建

Pytorch的nn模块提供了LSTM方法,具体接口使用说明可以参见Pytorch的接口使用说明书。此处调用nn.LSTM构建LSTM神经网络,模型另增加了线性变化的全连接层Linear(),但并未加入激活函数。由于是单个数值的预测,这里input_size和output_size都为1.

Pytorch 如何实现LSTM时间序列预测

4、训练和测试

(1)模型定义、损失函数定义

Pytorch 如何实现LSTM时间序列预测

(2)训练与测试

Pytorch 如何实现LSTM时间序列预测

(3)结果展示

比较模型预测序列结果与真实值之间的差距

Pytorch 如何实现LSTM时间序列预测

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Cython 三分钟入门教程
Sep 17 Python
Python函数返回值实例分析
Jun 08 Python
详解Python3中yield生成器的用法
Aug 20 Python
python jieba分词并统计词频后输出结果到Excel和txt文档方法
Feb 11 Python
解决python3 urllib 链接中有中文的问题
Jul 16 Python
numpy.ndarray 交换多维数组(矩阵)的行/列方法
Aug 02 Python
python gensim使用word2vec词向量处理中文语料的方法
Jul 05 Python
TensorFlow tensor的拼接实例
Jan 19 Python
python输出第n个默尼森数的实现示例
Mar 08 Python
python3 sorted 如何实现自定义排序标准
Mar 12 Python
Python面向对象魔法方法和单例模块代码实例
Mar 25 Python
python实战之90行代码写个猜数字游戏
Apr 22 Python
pytorch实现ResNet结构的实例代码
pytorch常用数据类型所占字节数对照表一览
May 17 #Python
python使用tkinter实现透明窗体上绘制随机出现的小球(实例代码)
Python编写可视化界面的全过程(Python+PyCharm+PyQt)
Pytorch 实现变量类型转换
Python进度条的使用
May 17 #Python
Python包管理工具pip的15 个使用小技巧
You might like
在PHP中执行系统外部命令
2006/10/09 PHP
php采集时被封ip的解决方法
2010/08/29 PHP
分享PHP守护进程类
2015/12/30 PHP
ThinkPHP的常用配置选项汇总
2016/03/24 PHP
ThinkPHP5.0 图片上传生成缩略图实例代码说明
2018/06/20 PHP
js禁止小键盘输入数字功能代码
2011/08/01 Javascript
一个JS的日期格式化算法示例
2013/07/31 Javascript
JQuery解析HTML、JSON和XML实例详解
2014/03/29 Javascript
jQuery实现选中弹出窗口选择框内容后赋值给文本框的方法
2015/11/23 Javascript
JS中数组重排序方法
2016/11/11 Javascript
JSON中key动态设置及JSON.parse和JSON.stringify()的区别
2016/12/29 Javascript
javascript 开发之网页兼容各种浏览器
2017/09/28 Javascript
浅谈Angular路由复用策略
2017/10/04 Javascript
Node使用Sequlize连接Mysql报错:Access denied for user ‘xxx’@‘localhost’
2018/01/03 Javascript
Vue结合Video.js播放m3u8视频流的方法示例
2018/05/04 Javascript
react中实现搜索结果中关键词高亮显示
2018/07/31 Javascript
简单了解微信小程序的目录结构
2019/07/01 Javascript
在vue中把含有html标签转为html渲染页面的实例
2019/10/28 Javascript
通过微信公众平台获取公众号文章的方法示例
2019/12/25 Javascript
浅析vue-router实现原理及两种模式
2020/02/11 Javascript
antd日期选择器禁止选择当天之前的时间操作
2020/10/29 Javascript
[02:26]2016国际邀请赛8月3日开战 中国军团出征西雅图
2016/08/02 DOTA
浅析Python 中整型对象存储的位置
2016/05/16 Python
Python Json模块中dumps、loads、dump、load函数介绍
2018/05/15 Python
在Pandas中处理NaN值的方法
2019/06/25 Python
详解CSS3实现响应式手风琴效果
2020/06/10 HTML / CSS
JD Sports德国官网:英国领先的运动鞋和运动服饰零售商
2018/02/26 全球购物
Mio Skincare英国官网:身体紧致及孕期身体护理
2018/08/19 全球购物
eDreams意大利:南欧领先的在线旅行社
2018/11/23 全球购物
LN-CC日本:高端男装和女装的奢侈时尚目的地
2019/09/01 全球购物
自主招生自荐信
2013/12/08 职场文书
倡议书格式
2014/04/14 职场文书
党支部先进事迹材料
2014/12/24 职场文书
网络管理员岗位职责
2015/02/12 职场文书
2015年校医个人工作总结
2015/07/24 职场文书
Tomcat配置访问日志和线程数
2022/05/06 Servers