Pytorch 如何实现LSTM时间序列预测


Posted in Python onMay 17, 2021

开发环境说明:

Python 35

Pytorch 0.2

CPU/GPU均可

1、LSTM简介

人类在进行学习时,往往不总是零开始,学习物理你会有数学基础、学习英语你会有中文基础等等。

于是对于机器而言,神经网络的学习亦可不再从零开始,于是出现了Transfer Learning,就是把一个领域已训练好的网络用于初始化另一个领域的任务,例如会下棋的神经网络可以用于打德州扑克。

我们这讲的是另一种不从零开始学习的神经网络——循环神经网络(Recurrent Neural Network, RNN),它的每一次迭代都是基于上一次的学习结果,不断循环以得到对于整体序列的学习,区别于传统的MLP神经网络,这种神经网络模型存在环型结构,

具体下所示:

Pytorch 如何实现LSTM时间序列预测

上图是RNN的基本单元,通过不断循环迭代展开模型如下所示,图中ht是神经网络的在t时刻的输出,xt是t时刻的输入数据。

这种循环结构对时间序列数据能够很好地建模,例如语音识别、语言建模、机器翻译等领域。

Pytorch 如何实现LSTM时间序列预测

但是普通的RNN对于长期依赖问题效果比较差,当序列本身比较长时,由于神经网络模型的训练是采用backward进行,在梯度链式法则中容易出现梯度消失和梯度爆炸的问题,需要进一步改进RNN的模型结构。

针对Simple RNN存在的问题,LSTM网络模型被提出,LSTM的核心是修改了增添了Cell State,即加入了LSTM CELL,通过输入门、输出门、遗忘门把上一时刻的hidden state和cell state传给下一个状态。

如下所示:

Pytorch 如何实现LSTM时间序列预测

遗忘门:ft = sigma(Wf*[ht-1, xt] + bf)

输入门:it = sigma(Wi*[ht-1, xt] + bi)

cell state initial: C't = tanh(Wc*[ht-1, xt] +bc)

cell state: Ct = ft*Ct-1+ itC't

输出门:ot = sigma(Wo*[ht-1, xt] + bo)

模型输出:ht = ot*tanh(Ct)

LSTM有很多种变型结构,实际工程化过程中用的比较多的是peephole,就是计算每个门的时候增添了cell state的信息,有兴趣的童鞋可以专研专研。

上一部分简单地介绍了LSTM的模型结构,下边将具体介绍使用LSTM模型进行时间序列预测的具体过程。

2、数据准备

对于时间序列,本文选取正弦波序列,事先产生一定数量的序列数据,然后截取前部分作为训练数据训练LSTM模型,后部分作为真实值与模型预测结果进行比较。正弦波的产生过程如下:

SeriesGen(N)方法用于产生长度为N的正弦波数值序列;

trainDataGen(seq,k)用于产生训练或测试数据,返回数据结构为输入输出数据。seq为序列数据,k为LSTM模型循环的长度,使用1~k的数据预测2~k+1的数据。

Pytorch 如何实现LSTM时间序列预测

3、模型构建

Pytorch的nn模块提供了LSTM方法,具体接口使用说明可以参见Pytorch的接口使用说明书。此处调用nn.LSTM构建LSTM神经网络,模型另增加了线性变化的全连接层Linear(),但并未加入激活函数。由于是单个数值的预测,这里input_size和output_size都为1.

Pytorch 如何实现LSTM时间序列预测

4、训练和测试

(1)模型定义、损失函数定义

Pytorch 如何实现LSTM时间序列预测

(2)训练与测试

Pytorch 如何实现LSTM时间序列预测

(3)结果展示

比较模型预测序列结果与真实值之间的差距

Pytorch 如何实现LSTM时间序列预测

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
如何用itertools解决无序排列组合的问题
May 18 Python
Python 将Matrix、Dict保存到文件的方法
Oct 30 Python
用python标准库difflib比较两份文件的异同详解
Nov 16 Python
在Pandas中DataFrame数据合并,连接(concat,merge,join)的实例
Jan 29 Python
利用Python查看微信共同好友功能的实现代码
Apr 24 Python
python-pyinstaller、打包后获取路径的实例
Jun 10 Python
Python FtpLib模块应用操作详解
Dec 12 Python
Python面向对象封装操作案例详解 II
Jan 02 Python
python中with用法讲解
Feb 07 Python
PyCharm设置Ipython交互环境和宏快捷键进行数据分析图文详解
Apr 23 Python
python的reverse函数翻转结果为None的问题
May 11 Python
M1芯片安装python3.9.1的实现
Feb 02 Python
pytorch实现ResNet结构的实例代码
pytorch常用数据类型所占字节数对照表一览
May 17 #Python
python使用tkinter实现透明窗体上绘制随机出现的小球(实例代码)
Python编写可视化界面的全过程(Python+PyCharm+PyQt)
Pytorch 实现变量类型转换
Python进度条的使用
May 17 #Python
Python包管理工具pip的15 个使用小技巧
You might like
PHP 5.3.1 安装包 VC9 VC6不同版本的区别是什么
2010/07/04 PHP
用 Composer构建自己的 PHP 框架之使用 ORM
2014/10/30 PHP
yii2.0之GridView自定义按钮和链接用法
2014/12/15 PHP
PIGCMS 如何关闭聊天机器人
2015/02/12 PHP
PHP图像处理技术实例总结【绘图、水印、验证码、图像压缩】
2018/12/08 PHP
php高性能日志系统 seaslog 的安装与使用方法分析
2020/02/29 PHP
做网页的一些技巧
2007/02/01 Javascript
轻轻松松学JS调试(不下载任何工具)
2010/04/14 Javascript
js获取网页可见区域、正文以及屏幕分辨率的高度
2014/05/15 Javascript
基于zepto.js实现仿手机QQ空间的大图查看组件ImageView.js详解
2015/03/05 Javascript
jQuery实现的简洁下拉菜单导航效果代码
2015/08/26 Javascript
JS只能输入正整数的简单实例
2016/10/07 Javascript
微信小程序 wx.request(OBJECT)发起请求详解
2016/10/13 Javascript
JS button按钮实现submit按钮提交效果
2016/11/01 Javascript
浅谈Javascript中的Label语句
2016/12/14 Javascript
JS模拟超市简易收银台小程序代码解析
2017/08/18 Javascript
详解.vue文件中监听input输入事件(oninput)
2017/09/19 Javascript
vue项目中的webpack-dev-sever配置方法
2017/12/14 Javascript
vue刷新和tab切换实例
2018/02/11 Javascript
swiper 自动图片无限轮播实现代码
2018/05/21 Javascript
nodejs express配置自签名https服务器的方法
2018/05/22 NodeJs
一个小时快速搭建微信小程序的方法步骤
2019/04/15 Javascript
基于Nuxt.js项目的服务端性能优化与错误检测(容错处理)
2019/10/23 Javascript
微信小程序封装多张图片上传api代码实例
2019/12/30 Javascript
小程序如何写动态标签的实现方法
2020/02/05 Javascript
前端深入理解Typescript泛型概念
2020/03/09 Javascript
jquery更改元素属性attr()方法操作示例
2020/05/22 jQuery
实例讲解React 组件生命周期
2020/07/08 Javascript
Antd的table组件表格的序号自增操作
2020/10/27 Javascript
Pyramid添加Middleware的方法实例
2013/11/27 Python
python3 读取Excel表格中的数据
2018/10/16 Python
Python全面分析系统的时域特性和频率域特性
2020/02/26 Python
jupyter notebook 重装教程
2020/04/16 Python
基于pytorch中的Sequential用法说明
2020/06/24 Python
html5 css3实例教程 一款html5和css3实现的小机器人走路动画
2014/10/20 HTML / CSS
西班牙语在线票务市场:SuperBoletería
2019/06/10 全球购物