pytorch-RNN进行回归曲线预测方式


Posted in Python onJanuary 14, 2020

任务

通过输入的sin曲线与预测出对应的cos曲线

#初始加载包 和定义参数
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
 
torch.manual_seed(1) #为了可复现
 
#超参数设定
TIME_SETP=10
INPUT_SIZE=1
LR=0.02
DOWNLoad_MNIST=True

定义RNN网络结构

from torch.autograd import Variable
class RNN(nn.Module):
  def __init__(self):
    #在这个函数中,两步走,先init,再逐步定义层结构
    super(RNN,self).__init__()
    
    self.rnn=nn.RNN(  #定义32隐层的rnn结构
     input_size=1,  
     hidden_size=32, #隐层有32个记忆体
     num_layers=1,   #隐层层数是1
     batch_first=True 
    )
    
    self.out=nn.Linear(32,1) #32个记忆体对应一个输出
  
  def forward(self,x,h_state):
    #前向过程,获取 rnn网络输出r_put(注意这里r_out并不是最后输出,最后要经过全连接层) 和 记忆体情况h_state
    r_out,h_state=self.rnn(x,h_state)    
    outs=[]#获取所有时间点下得到的预测值
    for time_step in range(r_out.size(1)): #将记忆rnn层的输出传到全连接层来得到最终输出。 这样每个输入对应一个输出,所以会有长度为10的输出
      outs.append(self.out(r_out[:,time_step,:]))
    return torch.stack(outs,dim=1),h_state #将10个数 通过stack方式压缩在一起
 
rnn=RNN()
print('RNN的网络体系结构为:',rnn)

pytorch-RNN进行回归曲线预测方式

创建数据集及网络训练

以sin曲线为特征,以cos曲线为标签进行网络的训练

#定义优化器和 损失函数
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #记录的隐藏层状态,记住这就是记忆体,初始时候为空,之后每次后面的都会使用到前面的记忆,自动生成全0的
       #这样加入记忆信息后,每次都会在之前的记忆矩阵基础上再进行新的训练,初始是全0的形式。
#启动训练,这里假定训练的批次为100次
 
 
plt.ion() #可以设定持续不断的绘图,但是在这里看还是间断的,这是jupyter的问题
for step in range(100):
  #我们以一个π为一个时间步  定义数据,
  start,end=step*np.pi,(step+1)*np.pi
  
  steps=np.linspace(start,end,10,dtype=np.float32) #注意这里的10并不是间隔为10,而是将数按范围分成10等分了
  
  x_np=np.sin(steps)
  y_np=np.cos(steps)
  #将numpy类型转成torch类型  *****当需要 求梯度时,一个 op 的两个输入都必须是要 Variable,输入的一定要variable包下
  x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加两个维度,是三维的数据。
  y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
  
  #将每个时间步上的10个值 输入到rnn获得结果   这里rnn会自动执行forward前向过程. 这里输入时10个,输出也是10个,传递的是一个长度为32的记忆体
  predition,h_state=rnn(x,h_state)
  
  #更新新的中间状态
  h_state=Variable(h_state.data)  #擦,这点一定要从新包装
  loss=loss_fun(predition,y)
  #print('loss:',loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  
  # plotting  画图,这里先平展了 flatten,这样就是得到一个数组,更加直接
  
  plt.plot(steps, y_np.flatten(), 'r-')
  plt.plot(steps, predition.data.numpy().flatten(), 'b-')
  #plt.draw(); 
  plt.pause(0.05)
 
plt.ioff() #关闭交互模式
plt.show()

pytorch-RNN进行回归曲线预测方式

以上这篇pytorch-RNN进行回归曲线预测方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python脚本操作MongoDB的教程
Apr 16 Python
Python实现各种排序算法的代码示例总结
Dec 11 Python
Python基础教程之浅拷贝和深拷贝实例详解
Jul 15 Python
TensorFlow 合并/连接数组的方法
Jul 27 Python
对python中数组的del,remove,pop区别详解
Nov 07 Python
python面向对象入门教程之从代码复用开始(一)
Dec 11 Python
解决python字典对值(值为列表)赋值出现重复的问题
Jan 20 Python
python用类实现文章敏感词的过滤方法示例
Oct 27 Python
Python 寻找局部最高点的实现
Dec 05 Python
python 已知三条边求三角形的角度案例
Apr 12 Python
基于pycharm实现批量修改变量名
Jun 02 Python
opencv python 对指针仪表读数识别的两种方式
Jan 14 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
You might like
PHP中使用gettext来支持多语言的方法
2011/05/02 PHP
PHP中字符与字节的区别及字符串与字节转换示例
2016/10/15 PHP
Nigma vs Alliance BO5 第五场2.14
2021/03/10 DOTA
30个最好的jQuery 灯箱插件分享
2011/04/25 Javascript
基于jQuery中对数组进行操作的方法
2013/04/16 Javascript
wap手机图片滑动切换特效无css3元素js脚本编写
2014/07/28 Javascript
Jquery中$.post和$.ajax的用法小结
2015/04/28 Javascript
windows下安装nodejs及框架express
2015/08/07 NodeJs
基于jQuery实现带动画效果超炫酷的弹出对话框(附源码下载)
2016/02/22 Javascript
AngularJS控制器详解及示例代码
2016/08/16 Javascript
js时间比较 js计算时间差的简单实现方法
2016/08/26 Javascript
ajax图片上传,图片异步上传,更新实例
2016/12/30 Javascript
Vue.js系列之项目搭建(1)
2017/01/03 Javascript
js实现无缝滚动图
2017/02/22 Javascript
利用jQuery实现一个简单的表格上下翻页效果
2017/03/14 Javascript
详解angularjs利用ui-route异步加载组件
2017/05/21 Javascript
JavaScript实现的可变动态数字键盘控件方式实例代码
2017/07/15 Javascript
JavaScript中的连续赋值问题实例分析
2019/07/12 Javascript
详解钉钉小程序组件之自定义模态框(弹窗封装实现)
2020/03/07 Javascript
package.json中homepage属性的作用详解
2020/03/11 Javascript
[53:44]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Magma BO3 第一场 1月31日
2021/03/11 DOTA
Python中文竖排显示的方法
2015/07/28 Python
Python视频爬虫实现下载头条视频功能示例
2018/05/07 Python
Pycharm 操作Django Model的简单运用方法
2018/05/23 Python
Python 函数list&read&seek详解
2019/08/28 Python
django 连接数据库出现1045错误的解决方式
2020/05/14 Python
英国电动工具购买网站:Anglia Tool Centre
2017/04/25 全球购物
捷克母婴用品购物网站:Feedo.cz
2020/12/28 全球购物
需求分析说明书
2014/05/09 职场文书
化验室岗位职责
2015/02/14 职场文书
慰问信格式
2015/02/14 职场文书
赵氏孤儿观后感
2015/06/09 职场文书
法人身份证明书
2015/06/18 职场文书
售房协议书范本
2015/08/11 职场文书
小程序wx.getUserProfile接口的具体使用
2021/06/02 Javascript
Python实现拼音转换
2021/06/07 Python