pytorch-RNN进行回归曲线预测方式


Posted in Python onJanuary 14, 2020

任务

通过输入的sin曲线与预测出对应的cos曲线

#初始加载包 和定义参数
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
 
torch.manual_seed(1) #为了可复现
 
#超参数设定
TIME_SETP=10
INPUT_SIZE=1
LR=0.02
DOWNLoad_MNIST=True

定义RNN网络结构

from torch.autograd import Variable
class RNN(nn.Module):
  def __init__(self):
    #在这个函数中,两步走,先init,再逐步定义层结构
    super(RNN,self).__init__()
    
    self.rnn=nn.RNN(  #定义32隐层的rnn结构
     input_size=1,  
     hidden_size=32, #隐层有32个记忆体
     num_layers=1,   #隐层层数是1
     batch_first=True 
    )
    
    self.out=nn.Linear(32,1) #32个记忆体对应一个输出
  
  def forward(self,x,h_state):
    #前向过程,获取 rnn网络输出r_put(注意这里r_out并不是最后输出,最后要经过全连接层) 和 记忆体情况h_state
    r_out,h_state=self.rnn(x,h_state)    
    outs=[]#获取所有时间点下得到的预测值
    for time_step in range(r_out.size(1)): #将记忆rnn层的输出传到全连接层来得到最终输出。 这样每个输入对应一个输出,所以会有长度为10的输出
      outs.append(self.out(r_out[:,time_step,:]))
    return torch.stack(outs,dim=1),h_state #将10个数 通过stack方式压缩在一起
 
rnn=RNN()
print('RNN的网络体系结构为:',rnn)

pytorch-RNN进行回归曲线预测方式

创建数据集及网络训练

以sin曲线为特征,以cos曲线为标签进行网络的训练

#定义优化器和 损失函数
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #记录的隐藏层状态,记住这就是记忆体,初始时候为空,之后每次后面的都会使用到前面的记忆,自动生成全0的
       #这样加入记忆信息后,每次都会在之前的记忆矩阵基础上再进行新的训练,初始是全0的形式。
#启动训练,这里假定训练的批次为100次
 
 
plt.ion() #可以设定持续不断的绘图,但是在这里看还是间断的,这是jupyter的问题
for step in range(100):
  #我们以一个π为一个时间步  定义数据,
  start,end=step*np.pi,(step+1)*np.pi
  
  steps=np.linspace(start,end,10,dtype=np.float32) #注意这里的10并不是间隔为10,而是将数按范围分成10等分了
  
  x_np=np.sin(steps)
  y_np=np.cos(steps)
  #将numpy类型转成torch类型  *****当需要 求梯度时,一个 op 的两个输入都必须是要 Variable,输入的一定要variable包下
  x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加两个维度,是三维的数据。
  y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
  
  #将每个时间步上的10个值 输入到rnn获得结果   这里rnn会自动执行forward前向过程. 这里输入时10个,输出也是10个,传递的是一个长度为32的记忆体
  predition,h_state=rnn(x,h_state)
  
  #更新新的中间状态
  h_state=Variable(h_state.data)  #擦,这点一定要从新包装
  loss=loss_fun(predition,y)
  #print('loss:',loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  
  # plotting  画图,这里先平展了 flatten,这样就是得到一个数组,更加直接
  
  plt.plot(steps, y_np.flatten(), 'r-')
  plt.plot(steps, predition.data.numpy().flatten(), 'b-')
  #plt.draw(); 
  plt.pause(0.05)
 
plt.ioff() #关闭交互模式
plt.show()

pytorch-RNN进行回归曲线预测方式

以上这篇pytorch-RNN进行回归曲线预测方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理中文编码和判断编码示例
Feb 26 Python
Python中的jquery PyQuery库使用小结
May 13 Python
python中self原理实例分析
Apr 30 Python
python实现在图片上画特定大小角度矩形框
Oct 24 Python
解决python写入带有中文的字符到文件错误的问题
Jan 31 Python
Python通过TensorFlow卷积神经网络实现猫狗识别
Mar 14 Python
Python使用crontab模块设置和清除定时任务操作详解
Apr 09 Python
Python单链表原理与实现方法详解
Feb 22 Python
Python函数的迭代器与生成器的示例代码
Jun 18 Python
Python经典五人分鱼实例讲解
Jan 04 Python
python 数据类型强制转换的总结
Jan 25 Python
如何将numpy二维数组中的np.nan值替换为指定的值
May 14 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
You might like
浅谈PHP 闭包特性在实际应用中的问题
2009/10/30 PHP
PHP连接MySQL进行增、删、改、查操作
2017/02/19 PHP
php微信开发之图片回复功能
2018/06/14 PHP
checkbox 复选框不能为空
2009/07/11 Javascript
jQuery.query.js 取参数的两点问题分析
2012/08/06 Javascript
jQuery的animate函数学习记录
2014/08/08 Javascript
轻松学习jQuery插件EasyUI EasyUI实现树形网络基本操作(2)
2015/11/30 Javascript
JS获取月份最后天数、最大天数与某日周数的方法
2015/12/08 Javascript
AngularJS基础 ng-csp 指令详解
2016/08/01 Javascript
js事件冒泡、事件捕获和阻止默认事件详解
2016/08/04 Javascript
利用JQuery阻止事件冒泡
2016/12/01 Javascript
jquery select插件异步实时搜索实例代码
2017/10/20 jQuery
微信小程序实现复选框效果
2018/12/28 Javascript
详解如何使用nvm管理Node.js多版本
2019/05/06 Javascript
vue项目前端错误收集之sentry教程详解
2019/05/27 Javascript
小程序实现图片预览裁剪插件
2019/11/22 Javascript
webpack优化之代码分割与公共代码提取详解
2019/11/22 Javascript
写一个Vue loading 插件
2020/11/09 Javascript
python2.7删除文件夹和删除文件代码实例
2013/12/18 Python
Python基础之函数用法实例详解
2014/09/10 Python
python django使用haystack:全文检索的框架(实例讲解)
2017/09/27 Python
如何在python中使用selenium的示例
2017/12/26 Python
python使用socket创建tcp服务器和客户端
2018/04/12 Python
Python何时应该使用Lambda函数
2019/07/02 Python
Django实现微信小程序的登录验证功能并维护登录态
2019/07/04 Python
基于Python中Remove函数的用法讨论
2020/12/11 Python
python绘图pyecharts+pandas的使用详解
2020/12/13 Python
构造方法和其他方法的区别?怎么调用父类的构造方法
2013/09/22 面试题
机械制造与自动化应届生求职信
2013/11/16 职场文书
信息专业学生学习的自我评价
2014/02/17 职场文书
彩色的非洲教学反思
2014/02/18 职场文书
中专生自荐信
2014/06/25 职场文书
班主任先进事迹材料
2014/12/17 职场文书
2014年变电站工作总结
2014/12/19 职场文书
2016应届大学生自荐信模板
2016/01/28 职场文书
python保存大型 .mat 数据文件报错超出 IO 限制的操作
2021/05/10 Python