pytorch-RNN进行回归曲线预测方式


Posted in Python onJanuary 14, 2020

任务

通过输入的sin曲线与预测出对应的cos曲线

#初始加载包 和定义参数
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
 
torch.manual_seed(1) #为了可复现
 
#超参数设定
TIME_SETP=10
INPUT_SIZE=1
LR=0.02
DOWNLoad_MNIST=True

定义RNN网络结构

from torch.autograd import Variable
class RNN(nn.Module):
  def __init__(self):
    #在这个函数中,两步走,先init,再逐步定义层结构
    super(RNN,self).__init__()
    
    self.rnn=nn.RNN(  #定义32隐层的rnn结构
     input_size=1,  
     hidden_size=32, #隐层有32个记忆体
     num_layers=1,   #隐层层数是1
     batch_first=True 
    )
    
    self.out=nn.Linear(32,1) #32个记忆体对应一个输出
  
  def forward(self,x,h_state):
    #前向过程,获取 rnn网络输出r_put(注意这里r_out并不是最后输出,最后要经过全连接层) 和 记忆体情况h_state
    r_out,h_state=self.rnn(x,h_state)    
    outs=[]#获取所有时间点下得到的预测值
    for time_step in range(r_out.size(1)): #将记忆rnn层的输出传到全连接层来得到最终输出。 这样每个输入对应一个输出,所以会有长度为10的输出
      outs.append(self.out(r_out[:,time_step,:]))
    return torch.stack(outs,dim=1),h_state #将10个数 通过stack方式压缩在一起
 
rnn=RNN()
print('RNN的网络体系结构为:',rnn)

pytorch-RNN进行回归曲线预测方式

创建数据集及网络训练

以sin曲线为特征,以cos曲线为标签进行网络的训练

#定义优化器和 损失函数
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #记录的隐藏层状态,记住这就是记忆体,初始时候为空,之后每次后面的都会使用到前面的记忆,自动生成全0的
       #这样加入记忆信息后,每次都会在之前的记忆矩阵基础上再进行新的训练,初始是全0的形式。
#启动训练,这里假定训练的批次为100次
 
 
plt.ion() #可以设定持续不断的绘图,但是在这里看还是间断的,这是jupyter的问题
for step in range(100):
  #我们以一个π为一个时间步  定义数据,
  start,end=step*np.pi,(step+1)*np.pi
  
  steps=np.linspace(start,end,10,dtype=np.float32) #注意这里的10并不是间隔为10,而是将数按范围分成10等分了
  
  x_np=np.sin(steps)
  y_np=np.cos(steps)
  #将numpy类型转成torch类型  *****当需要 求梯度时,一个 op 的两个输入都必须是要 Variable,输入的一定要variable包下
  x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加两个维度,是三维的数据。
  y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
  
  #将每个时间步上的10个值 输入到rnn获得结果   这里rnn会自动执行forward前向过程. 这里输入时10个,输出也是10个,传递的是一个长度为32的记忆体
  predition,h_state=rnn(x,h_state)
  
  #更新新的中间状态
  h_state=Variable(h_state.data)  #擦,这点一定要从新包装
  loss=loss_fun(predition,y)
  #print('loss:',loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  
  # plotting  画图,这里先平展了 flatten,这样就是得到一个数组,更加直接
  
  plt.plot(steps, y_np.flatten(), 'r-')
  plt.plot(steps, predition.data.numpy().flatten(), 'b-')
  #plt.draw(); 
  plt.pause(0.05)
 
plt.ioff() #关闭交互模式
plt.show()

pytorch-RNN进行回归曲线预测方式

以上这篇pytorch-RNN进行回归曲线预测方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取文件后缀名及批量更新目录下文件后缀名的方法
Nov 11 Python
Python脚本实现虾米网签到功能
Apr 12 Python
解决Python字典写入文件出行首行有空格的问题
Sep 27 Python
ubuntu17.4下为python和python3装上pip的方法
Jun 12 Python
python生成九宫格图片
Nov 19 Python
python filecmp.dircmp实现递归比对两个目录的方法
May 22 Python
Python如何基于Tesseract实现识别文字功能
Jun 05 Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 Python
解决pip install psycopg2出错问题
Jul 09 Python
用python写PDF转换器的实现
Oct 29 Python
PyTorch的Debug指南
May 07 Python
python flappy bird小游戏分步实现流程
Feb 15 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
You might like
php用数组返回无限分类的列表数据的代码
2010/08/08 PHP
如何突破PHP程序员的技术瓶颈分析
2011/07/17 PHP
PHP里的中文变量说明
2011/07/23 PHP
Symfony2学习笔记之控制器用法详解
2016/03/17 PHP
JS this作用域以及GET传输值过长的问题解决方法
2013/08/06 Javascript
Jquery实现侧边栏跟随滚动条固定(兼容IE6)
2014/04/02 Javascript
JavaScript组合拼接字符串的效率对比测试
2014/11/06 Javascript
jquery读取xml文件实现省市县三级联动的方法
2015/05/29 Javascript
javascript实现下班倒计时效果的方法(可桌面通知)
2015/07/10 Javascript
JS获取当前使用的浏览器名字以及版本号实现方法
2016/08/19 Javascript
jQuery 检查某个元素在页面上是否存在实例代码
2016/10/27 Javascript
Angular ng-repeat指令实例以及扩展部分
2016/12/26 Javascript
Ajax基础知识详解
2017/02/17 Javascript
javascript九宫格图片随机打乱位置的实现方法
2017/03/15 Javascript
vue2中的keep-alive使用总结及注意事项
2017/12/21 Javascript
js实现简单的打印表格
2020/01/15 Javascript
JQuery表单元素取值赋值方法总结
2020/05/12 jQuery
微信小程序实现底部弹出框
2020/11/18 Javascript
python 判断网络连通的实现方法
2018/04/22 Python
详解django的serializer序列化model几种方法
2018/10/16 Python
python实现翻转棋游戏(othello)
2019/07/29 Python
jupyter lab文件导出/下载方式
2020/04/22 Python
在pytorch中动态调整优化器的学习率方式
2020/06/24 Python
Python GUI库Tkiner使用方法代码示例
2020/11/27 Python
简单的HTML5初步入门教程
2015/09/29 HTML / CSS
瑞典时尚耳机品牌:Urbanears
2017/07/26 全球购物
关于.NET, HTML的五个问题
2012/08/29 面试题
营销与策划个人求职信
2013/09/22 职场文书
人力资源经理的岗位职责
2014/03/02 职场文书
2015年班组建设工作总结
2015/05/13 职场文书
孙振耀退休感言
2015/08/01 职场文书
2016三八妇女节校园广播稿
2015/12/17 职场文书
写作技巧:怎样写好一份优秀工作总结?
2019/08/14 职场文书
导游词之江苏同里古镇
2019/11/18 职场文书
高端收音机+蓝牙音箱,JBL TUNER FM带收音蓝牙音箱评测
2021/04/24 无线电
详解非极大值抑制算法之Python实现
2021/06/28 Python