pytorch-RNN进行回归曲线预测方式


Posted in Python onJanuary 14, 2020

任务

通过输入的sin曲线与预测出对应的cos曲线

#初始加载包 和定义参数
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
 
torch.manual_seed(1) #为了可复现
 
#超参数设定
TIME_SETP=10
INPUT_SIZE=1
LR=0.02
DOWNLoad_MNIST=True

定义RNN网络结构

from torch.autograd import Variable
class RNN(nn.Module):
  def __init__(self):
    #在这个函数中,两步走,先init,再逐步定义层结构
    super(RNN,self).__init__()
    
    self.rnn=nn.RNN(  #定义32隐层的rnn结构
     input_size=1,  
     hidden_size=32, #隐层有32个记忆体
     num_layers=1,   #隐层层数是1
     batch_first=True 
    )
    
    self.out=nn.Linear(32,1) #32个记忆体对应一个输出
  
  def forward(self,x,h_state):
    #前向过程,获取 rnn网络输出r_put(注意这里r_out并不是最后输出,最后要经过全连接层) 和 记忆体情况h_state
    r_out,h_state=self.rnn(x,h_state)    
    outs=[]#获取所有时间点下得到的预测值
    for time_step in range(r_out.size(1)): #将记忆rnn层的输出传到全连接层来得到最终输出。 这样每个输入对应一个输出,所以会有长度为10的输出
      outs.append(self.out(r_out[:,time_step,:]))
    return torch.stack(outs,dim=1),h_state #将10个数 通过stack方式压缩在一起
 
rnn=RNN()
print('RNN的网络体系结构为:',rnn)

pytorch-RNN进行回归曲线预测方式

创建数据集及网络训练

以sin曲线为特征,以cos曲线为标签进行网络的训练

#定义优化器和 损失函数
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #记录的隐藏层状态,记住这就是记忆体,初始时候为空,之后每次后面的都会使用到前面的记忆,自动生成全0的
       #这样加入记忆信息后,每次都会在之前的记忆矩阵基础上再进行新的训练,初始是全0的形式。
#启动训练,这里假定训练的批次为100次
 
 
plt.ion() #可以设定持续不断的绘图,但是在这里看还是间断的,这是jupyter的问题
for step in range(100):
  #我们以一个π为一个时间步  定义数据,
  start,end=step*np.pi,(step+1)*np.pi
  
  steps=np.linspace(start,end,10,dtype=np.float32) #注意这里的10并不是间隔为10,而是将数按范围分成10等分了
  
  x_np=np.sin(steps)
  y_np=np.cos(steps)
  #将numpy类型转成torch类型  *****当需要 求梯度时,一个 op 的两个输入都必须是要 Variable,输入的一定要variable包下
  x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加两个维度,是三维的数据。
  y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
  
  #将每个时间步上的10个值 输入到rnn获得结果   这里rnn会自动执行forward前向过程. 这里输入时10个,输出也是10个,传递的是一个长度为32的记忆体
  predition,h_state=rnn(x,h_state)
  
  #更新新的中间状态
  h_state=Variable(h_state.data)  #擦,这点一定要从新包装
  loss=loss_fun(predition,y)
  #print('loss:',loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  
  # plotting  画图,这里先平展了 flatten,这样就是得到一个数组,更加直接
  
  plt.plot(steps, y_np.flatten(), 'r-')
  plt.plot(steps, predition.data.numpy().flatten(), 'b-')
  #plt.draw(); 
  plt.pause(0.05)
 
plt.ioff() #关闭交互模式
plt.show()

pytorch-RNN进行回归曲线预测方式

以上这篇pytorch-RNN进行回归曲线预测方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 自动化将markdown文件转成html文件的方法
Sep 23 Python
Python实现Linux的find命令实例分享
Jun 04 Python
Python之Scrapy爬虫框架安装及使用详解
Nov 16 Python
机器学习的框架偏向于Python的13个原因
Dec 07 Python
Python实现KNN邻近算法
Jan 28 Python
对numpy中的数组条件筛选功能详解
Jul 02 Python
Python使用post及get方式提交数据的实例
Jan 24 Python
python redis连接 有序集合去重的代码
Aug 04 Python
Python异常模块traceback用法实例分析
Oct 22 Python
使用TensorFlow-Slim进行图像分类的实现
Dec 31 Python
python opencv如何实现图片绘制
Jan 19 Python
Python预测2020高考分数和录取情况
Jul 08 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
You might like
来自PHP.NET的入门教程
2006/10/09 PHP
PHP strtr() 函数使用说明
2008/11/21 PHP
让PHP以ROOT权限执行系统命令的方法
2011/02/10 PHP
优化PHP程序的方法小结
2012/02/23 PHP
php获取数据库中数据的实现方法
2017/06/01 PHP
因str_replace导致的注入问题总结
2019/08/08 PHP
PHP基于进程控制函数实现多线程
2020/12/09 PHP
用javascript实现兼容IE7的类库 IE7_0_9.zip提供下载
2007/08/08 Javascript
ExtJs3.0中Store添加 baseParams 的Bug
2010/03/10 Javascript
浅谈JS日期(Date)处理函数
2014/12/07 Javascript
使用postMesssage()实现跨域iframe页面间的信息传递方法
2016/03/29 Javascript
轻松掌握JavaScript状态模式
2016/09/07 Javascript
js以分隔符分隔数组中的元素并转换为字符串的方法
2016/11/16 Javascript
angularjs2 ng2 密码隐藏显示的实例代码
2017/08/01 Javascript
Angular使用过滤器uppercase/lowercase实现字母大小写转换功能示例
2018/03/27 Javascript
JS使用setInterval实现的简单计时器功能示例
2018/04/19 Javascript
JavaScript常用数学函数用法示例
2018/05/14 Javascript
微信小程序自定义对话框弹出和隐藏动画
2018/07/19 Javascript
详解JavaScript之ES5的继承
2020/07/08 Javascript
Python的动态重新封装的教程
2015/04/11 Python
python smtplib模块自动收发邮件功能(一)
2018/05/22 Python
python bmp转换为jpg 并删除原图的方法
2018/10/25 Python
wxPython实现分隔窗口
2019/11/19 Python
Jupyter notebook如何修改平台字体
2020/05/13 Python
keras多显卡训练方式
2020/06/10 Python
浅谈keras 模型用于预测时的注意事项
2020/06/27 Python
详解Python中的路径问题
2020/09/02 Python
整理HTML5移动端开发的常用触摸事件
2016/04/15 HTML / CSS
英国排名第一的最新设计师品牌手表独立零售商:TIC Watches
2016/09/24 全球购物
德国家具折扣店:POCO
2020/02/28 全球购物
工程造价自荐信
2013/10/09 职场文书
模范家庭事迹材料
2014/02/10 职场文书
银行简历自我评价
2014/02/11 职场文书
2014年教研组工作总结
2014/11/26 职场文书
主持人大赛开场白
2015/05/29 职场文书
2016元旦主持人开场白
2015/12/03 职场文书