pytorch-RNN进行回归曲线预测方式


Posted in Python onJanuary 14, 2020

任务

通过输入的sin曲线与预测出对应的cos曲线

#初始加载包 和定义参数
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
 
torch.manual_seed(1) #为了可复现
 
#超参数设定
TIME_SETP=10
INPUT_SIZE=1
LR=0.02
DOWNLoad_MNIST=True

定义RNN网络结构

from torch.autograd import Variable
class RNN(nn.Module):
  def __init__(self):
    #在这个函数中,两步走,先init,再逐步定义层结构
    super(RNN,self).__init__()
    
    self.rnn=nn.RNN(  #定义32隐层的rnn结构
     input_size=1,  
     hidden_size=32, #隐层有32个记忆体
     num_layers=1,   #隐层层数是1
     batch_first=True 
    )
    
    self.out=nn.Linear(32,1) #32个记忆体对应一个输出
  
  def forward(self,x,h_state):
    #前向过程,获取 rnn网络输出r_put(注意这里r_out并不是最后输出,最后要经过全连接层) 和 记忆体情况h_state
    r_out,h_state=self.rnn(x,h_state)    
    outs=[]#获取所有时间点下得到的预测值
    for time_step in range(r_out.size(1)): #将记忆rnn层的输出传到全连接层来得到最终输出。 这样每个输入对应一个输出,所以会有长度为10的输出
      outs.append(self.out(r_out[:,time_step,:]))
    return torch.stack(outs,dim=1),h_state #将10个数 通过stack方式压缩在一起
 
rnn=RNN()
print('RNN的网络体系结构为:',rnn)

pytorch-RNN进行回归曲线预测方式

创建数据集及网络训练

以sin曲线为特征,以cos曲线为标签进行网络的训练

#定义优化器和 损失函数
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #记录的隐藏层状态,记住这就是记忆体,初始时候为空,之后每次后面的都会使用到前面的记忆,自动生成全0的
       #这样加入记忆信息后,每次都会在之前的记忆矩阵基础上再进行新的训练,初始是全0的形式。
#启动训练,这里假定训练的批次为100次
 
 
plt.ion() #可以设定持续不断的绘图,但是在这里看还是间断的,这是jupyter的问题
for step in range(100):
  #我们以一个π为一个时间步  定义数据,
  start,end=step*np.pi,(step+1)*np.pi
  
  steps=np.linspace(start,end,10,dtype=np.float32) #注意这里的10并不是间隔为10,而是将数按范围分成10等分了
  
  x_np=np.sin(steps)
  y_np=np.cos(steps)
  #将numpy类型转成torch类型  *****当需要 求梯度时,一个 op 的两个输入都必须是要 Variable,输入的一定要variable包下
  x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加两个维度,是三维的数据。
  y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
  
  #将每个时间步上的10个值 输入到rnn获得结果   这里rnn会自动执行forward前向过程. 这里输入时10个,输出也是10个,传递的是一个长度为32的记忆体
  predition,h_state=rnn(x,h_state)
  
  #更新新的中间状态
  h_state=Variable(h_state.data)  #擦,这点一定要从新包装
  loss=loss_fun(predition,y)
  #print('loss:',loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  
  # plotting  画图,这里先平展了 flatten,这样就是得到一个数组,更加直接
  
  plt.plot(steps, y_np.flatten(), 'r-')
  plt.plot(steps, predition.data.numpy().flatten(), 'b-')
  #plt.draw(); 
  plt.pause(0.05)
 
plt.ioff() #关闭交互模式
plt.show()

pytorch-RNN进行回归曲线预测方式

以上这篇pytorch-RNN进行回归曲线预测方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python json模块使用实例
Apr 11 Python
解决Python出现_warn_unsafe_extraction问题的方法
Mar 24 Python
Python 实现一个颜色色值转换的小工具
Dec 06 Python
Python机器学习之决策树算法
Dec 22 Python
pytorch构建网络模型的4种方法
Apr 13 Python
Python利用递归实现文件的复制方法
Oct 27 Python
Python面向对象之类的定义与继承用法示例
Jan 14 Python
详解Python sys.argv使用方法
May 10 Python
Python 利用高德地图api实现经纬度与地址的批量转换
Aug 14 Python
Python使用APScheduler实现定时任务过程解析
Sep 11 Python
解析Tensorflow之MNIST的使用
Jun 30 Python
django教程如何自学
Jul 31 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
You might like
php 静态变量与自定义常量的使用方法
2010/01/26 PHP
PHP 事务处理数据实现代码
2010/05/13 PHP
基于Linux调试工具strace与gdb的常用命令总结
2013/06/03 PHP
Destoon旺旺无法正常显示,点击提示“会员名不存在”的解决办法
2014/06/21 PHP
php实现支持中文的文件下载功能示例
2017/08/30 PHP
onpropertypchange
2006/07/01 Javascript
用javascript实现读取txt文档的脚本
2007/07/20 Javascript
Javascript 获取LI里的内容
2008/12/17 Javascript
javascript,jquery闭包概念分析
2010/06/19 Javascript
基于node.js的快速开发透明代理
2010/12/25 Javascript
javascript元素动态创建实现方法
2015/05/13 Javascript
jquery用ajax方式从后台获取json数据后如何将内容填充到下拉列表
2015/08/26 Javascript
jQuery实现输入框的放大和缩小功能示例
2018/07/21 jQuery
详解Nuxt.js 实战集锦
2019/11/19 Javascript
Vue中登录验证成功后保存token,并每次请求携带并验证token操作
2020/09/08 Javascript
python自定义异常实例详解
2017/07/11 Python
Python实现中一次读取多个值的方法
2018/04/22 Python
django请求返回不同的类型图片json,xml,html的实例
2018/05/22 Python
python TKinter获取文本框内容的方法
2018/10/11 Python
python集合是否可变总结
2019/06/20 Python
Python中pymysql 模块的使用详解
2019/08/12 Python
Python树莓派学习笔记之UDP传输视频帧操作详解
2019/11/15 Python
python中rb含义理解
2020/06/18 Python
移动端Web页面的CSS3 flex布局快速上手指南
2016/05/31 HTML / CSS
HTML5 canvas基本绘图之图形组合
2016/06/27 HTML / CSS
亚马逊巴西站:Amazon.com.br
2019/09/22 全球购物
六五普法宣传标语
2014/10/06 职场文书
拾金不昧表扬稿大全
2015/05/05 职场文书
太空授课观后感
2015/06/17 职场文书
幼儿园开学温馨提示
2015/07/15 职场文书
党员反四风学习心得体会
2016/01/22 职场文书
如何使用Python对NetCDF数据做空间相关分析
2021/04/21 Python
Vue实现跑马灯样式文字横向滚动
2021/11/23 Vue.js
mybatis源码解读之executor包语句处理功能
2022/02/15 Java/Android
「我的青春恋爱物语果然有问题。-妄言录-」第20卷封面公开
2022/03/21 日漫
Python pyecharts案例超市4年数据可视化分析
2022/08/14 Python