pytorch-RNN进行回归曲线预测方式


Posted in Python onJanuary 14, 2020

任务

通过输入的sin曲线与预测出对应的cos曲线

#初始加载包 和定义参数
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
 
torch.manual_seed(1) #为了可复现
 
#超参数设定
TIME_SETP=10
INPUT_SIZE=1
LR=0.02
DOWNLoad_MNIST=True

定义RNN网络结构

from torch.autograd import Variable
class RNN(nn.Module):
  def __init__(self):
    #在这个函数中,两步走,先init,再逐步定义层结构
    super(RNN,self).__init__()
    
    self.rnn=nn.RNN(  #定义32隐层的rnn结构
     input_size=1,  
     hidden_size=32, #隐层有32个记忆体
     num_layers=1,   #隐层层数是1
     batch_first=True 
    )
    
    self.out=nn.Linear(32,1) #32个记忆体对应一个输出
  
  def forward(self,x,h_state):
    #前向过程,获取 rnn网络输出r_put(注意这里r_out并不是最后输出,最后要经过全连接层) 和 记忆体情况h_state
    r_out,h_state=self.rnn(x,h_state)    
    outs=[]#获取所有时间点下得到的预测值
    for time_step in range(r_out.size(1)): #将记忆rnn层的输出传到全连接层来得到最终输出。 这样每个输入对应一个输出,所以会有长度为10的输出
      outs.append(self.out(r_out[:,time_step,:]))
    return torch.stack(outs,dim=1),h_state #将10个数 通过stack方式压缩在一起
 
rnn=RNN()
print('RNN的网络体系结构为:',rnn)

pytorch-RNN进行回归曲线预测方式

创建数据集及网络训练

以sin曲线为特征,以cos曲线为标签进行网络的训练

#定义优化器和 损失函数
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #记录的隐藏层状态,记住这就是记忆体,初始时候为空,之后每次后面的都会使用到前面的记忆,自动生成全0的
       #这样加入记忆信息后,每次都会在之前的记忆矩阵基础上再进行新的训练,初始是全0的形式。
#启动训练,这里假定训练的批次为100次
 
 
plt.ion() #可以设定持续不断的绘图,但是在这里看还是间断的,这是jupyter的问题
for step in range(100):
  #我们以一个π为一个时间步  定义数据,
  start,end=step*np.pi,(step+1)*np.pi
  
  steps=np.linspace(start,end,10,dtype=np.float32) #注意这里的10并不是间隔为10,而是将数按范围分成10等分了
  
  x_np=np.sin(steps)
  y_np=np.cos(steps)
  #将numpy类型转成torch类型  *****当需要 求梯度时,一个 op 的两个输入都必须是要 Variable,输入的一定要variable包下
  x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加两个维度,是三维的数据。
  y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
  
  #将每个时间步上的10个值 输入到rnn获得结果   这里rnn会自动执行forward前向过程. 这里输入时10个,输出也是10个,传递的是一个长度为32的记忆体
  predition,h_state=rnn(x,h_state)
  
  #更新新的中间状态
  h_state=Variable(h_state.data)  #擦,这点一定要从新包装
  loss=loss_fun(predition,y)
  #print('loss:',loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  
  # plotting  画图,这里先平展了 flatten,这样就是得到一个数组,更加直接
  
  plt.plot(steps, y_np.flatten(), 'r-')
  plt.plot(steps, predition.data.numpy().flatten(), 'b-')
  #plt.draw(); 
  plt.pause(0.05)
 
plt.ioff() #关闭交互模式
plt.show()

pytorch-RNN进行回归曲线预测方式

以上这篇pytorch-RNN进行回归曲线预测方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python打开url并按指定块读取网页内容的方法
Apr 29 Python
Python实现查询某个目录下修改时间最新的文件示例
Aug 29 Python
python hook监听事件详解
Oct 25 Python
python创造虚拟环境方法总结
Mar 04 Python
PyCharm搭建Spark开发环境实现第一个pyspark程序
Jun 13 Python
用python做游戏的细节详解
Jun 25 Python
python装饰器常见使用方法分析
Jun 26 Python
OpenCV 模板匹配
Jul 10 Python
python列表,字典,元组简单用法示例
Jul 11 Python
TensorFlow2.0:张量的合并与分割实例
Jan 19 Python
对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解
Feb 11 Python
python 爬虫网页登陆的简单实现
Nov 30 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
You might like
CodeIgniter上传图片成功的全部过程分享
2013/08/12 PHP
PHP利用imagick生成组合缩略图
2016/02/19 PHP
Symfony模板的快捷变量用法实例
2016/03/17 PHP
PHP QRCODE生成彩色二维码的方法
2016/05/19 PHP
功能强大的php分页函数
2016/07/20 PHP
Laravel 框架基于自带的用户系统实现登录注册及错误处理功能分析
2020/04/14 PHP
javascript实现节点(div)名称编辑
2014/12/17 Javascript
chrome不支持form.submit的解决方案
2015/04/28 Javascript
jQuery表单验证功能实例
2015/08/28 Javascript
JavaScript缓冲运动实现方法(2则示例)
2016/01/08 Javascript
JavaScript评论点赞功能的实现方法
2017/03/13 Javascript
jQuery实现多张图片上传预览(不经过后端处理)
2017/04/29 jQuery
关于redux-saga中take使用方法详解
2018/02/27 Javascript
jQuery.validate.js表单验证插件的使用代码详解
2018/10/22 jQuery
微信内置开发 iOS修改键盘换行为搜索的解决方案
2019/11/06 Javascript
windows如何把已安装的nodejs高版本降级为低版本(图文教程)
2020/12/14 NodeJs
使用python在校内发人人网状态(人人网看状态)
2014/02/19 Python
python实现教务管理系统
2018/03/12 Python
利用Anaconda简单安装scrapy框架的方法
2018/06/13 Python
Python高级特性切片(Slice)操作详解
2018/09/27 Python
python 高效去重复 支持GB级别大文件的示例代码
2018/11/08 Python
JupyterNotebook设置Python环境的方法步骤
2019/12/03 Python
DataFrame 数据合并实现(merge,join,concat)
2020/06/14 Python
使用Python提取文本中含有特定字符串的方法示例
2020/12/09 Python
使用sublime text3搭建Python编辑环境的实现
2021/01/12 Python
css3 column实现卡片瀑布流布局的示例代码
2018/06/22 HTML / CSS
中海讯通笔试题
2015/09/15 面试题
幼儿园美术教学反思
2014/01/31 职场文书
人力资源部经理岗位职责规定
2014/02/23 职场文书
财产公证书样本
2014/04/04 职场文书
公路绿化方案
2014/05/12 职场文书
环境建议书
2015/02/04 职场文书
2015年党员岗位承诺书
2015/04/27 职场文书
幼儿园保教工作总结2015
2015/10/15 职场文书
Python djanjo之csrf防跨站攻击实验过程
2021/05/14 Python
配置Kubernetes外网访问集群
2022/03/31 Servers