pytorch-RNN进行回归曲线预测方式


Posted in Python onJanuary 14, 2020

任务

通过输入的sin曲线与预测出对应的cos曲线

#初始加载包 和定义参数
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
 
torch.manual_seed(1) #为了可复现
 
#超参数设定
TIME_SETP=10
INPUT_SIZE=1
LR=0.02
DOWNLoad_MNIST=True

定义RNN网络结构

from torch.autograd import Variable
class RNN(nn.Module):
  def __init__(self):
    #在这个函数中,两步走,先init,再逐步定义层结构
    super(RNN,self).__init__()
    
    self.rnn=nn.RNN(  #定义32隐层的rnn结构
     input_size=1,  
     hidden_size=32, #隐层有32个记忆体
     num_layers=1,   #隐层层数是1
     batch_first=True 
    )
    
    self.out=nn.Linear(32,1) #32个记忆体对应一个输出
  
  def forward(self,x,h_state):
    #前向过程,获取 rnn网络输出r_put(注意这里r_out并不是最后输出,最后要经过全连接层) 和 记忆体情况h_state
    r_out,h_state=self.rnn(x,h_state)    
    outs=[]#获取所有时间点下得到的预测值
    for time_step in range(r_out.size(1)): #将记忆rnn层的输出传到全连接层来得到最终输出。 这样每个输入对应一个输出,所以会有长度为10的输出
      outs.append(self.out(r_out[:,time_step,:]))
    return torch.stack(outs,dim=1),h_state #将10个数 通过stack方式压缩在一起
 
rnn=RNN()
print('RNN的网络体系结构为:',rnn)

pytorch-RNN进行回归曲线预测方式

创建数据集及网络训练

以sin曲线为特征,以cos曲线为标签进行网络的训练

#定义优化器和 损失函数
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_fun=nn.MSELoss()
h_state=None #记录的隐藏层状态,记住这就是记忆体,初始时候为空,之后每次后面的都会使用到前面的记忆,自动生成全0的
       #这样加入记忆信息后,每次都会在之前的记忆矩阵基础上再进行新的训练,初始是全0的形式。
#启动训练,这里假定训练的批次为100次
 
 
plt.ion() #可以设定持续不断的绘图,但是在这里看还是间断的,这是jupyter的问题
for step in range(100):
  #我们以一个π为一个时间步  定义数据,
  start,end=step*np.pi,(step+1)*np.pi
  
  steps=np.linspace(start,end,10,dtype=np.float32) #注意这里的10并不是间隔为10,而是将数按范围分成10等分了
  
  x_np=np.sin(steps)
  y_np=np.cos(steps)
  #将numpy类型转成torch类型  *****当需要 求梯度时,一个 op 的两个输入都必须是要 Variable,输入的一定要variable包下
  x=Variable(torch.from_numpy(x_np[np.newaxis,:,np.newaxis]))#增加两个维度,是三维的数据。
  y=Variable(torch.from_numpy(y_np[np.newaxis,:,np.newaxis]))
  
  #将每个时间步上的10个值 输入到rnn获得结果   这里rnn会自动执行forward前向过程. 这里输入时10个,输出也是10个,传递的是一个长度为32的记忆体
  predition,h_state=rnn(x,h_state)
  
  #更新新的中间状态
  h_state=Variable(h_state.data)  #擦,这点一定要从新包装
  loss=loss_fun(predition,y)
  #print('loss:',loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  
  # plotting  画图,这里先平展了 flatten,这样就是得到一个数组,更加直接
  
  plt.plot(steps, y_np.flatten(), 'r-')
  plt.plot(steps, predition.data.numpy().flatten(), 'b-')
  #plt.draw(); 
  plt.pause(0.05)
 
plt.ioff() #关闭交互模式
plt.show()

pytorch-RNN进行回归曲线预测方式

以上这篇pytorch-RNN进行回归曲线预测方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Linux下编译安装MySQL-Python教程
Feb 02 Python
python3实现短网址和数字相互转换的方法
Apr 28 Python
基于Python实现文件大小输出
Jan 11 Python
Python中使用asyncio 封装文件读写
Sep 11 Python
python3实现字符串的全排列的方法(无重复字符)
Jul 07 Python
python os.path模块常用方法实例详解
Sep 16 Python
python爬虫获取新浪新闻教学
Dec 23 Python
详解Python3定时器任务代码
Sep 23 Python
jupyter notebook 的工作空间设置操作
Apr 20 Python
Pytest测试框架基本使用方法详解
Nov 25 Python
python 爬取英雄联盟皮肤并下载的示例
Dec 04 Python
详解Python中下划线的5种含义
Jul 15 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
You might like
PHP 七大优势分析
2009/06/23 PHP
PHP新手NOTICE错误常见解决方法
2011/12/07 PHP
php表单提交与$_POST实例分析
2015/01/26 PHP
php使用cookie显示用户上次访问网站日期的方法
2015/01/26 PHP
php实现改变图片直接打开为下载的方法
2015/04/14 PHP
php实现给二维数组中所有一维数组添加值的方法
2017/02/04 PHP
40个新鲜出炉的jQuery 插件和免费教程[上]
2012/07/24 Javascript
jquery使用ColorBox弹出图片组浏览层实例演示
2013/03/14 Javascript
JS 获取浏览器和屏幕宽高等信息的实现思路及代码
2013/07/31 Javascript
ParseInt函数参数设置介绍
2014/01/02 Javascript
js实现从右向左缓缓浮出网页浮动层广告的方法
2015/05/09 Javascript
JavaSacript中charCodeAt()方法的使用详解
2015/06/05 Javascript
jquery简单实现幻灯片的方法
2015/08/03 Javascript
js生成随机颜色方法代码分享(三种)
2016/12/29 Javascript
js仿搜狐视频记录片列表展示效果
2020/05/30 Javascript
Vue 进阶教程之v-model详解
2017/05/06 Javascript
react native 获取地理位置的方法示例
2018/08/28 Javascript
傻瓜式解读koa中间件处理模块koa-compose的使用
2018/10/30 Javascript
js实现聊天对话框
2020/02/08 Javascript
[46:50]Liquid vs Mineski 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
python os.listdir按文件存取时间顺序列出目录的实例
2018/10/21 Python
python3实现网络爬虫之BeautifulSoup使用详解
2018/12/19 Python
PyQt5笔记之弹出窗口大全
2019/06/20 Python
pytorch打印网络结构的实例
2019/08/19 Python
详解如何在cmd命令窗口中搭建简单的python开发环境
2019/08/29 Python
python列表推导和生成器表达式知识点总结
2020/01/10 Python
演讲稿怎么写才完美
2014/01/02 职场文书
公证书样本
2014/04/10 职场文书
一年级小学生评语
2014/04/22 职场文书
2014年预备党员群众路线教育实践活动对照检查材料思想汇报
2014/10/02 职场文书
2014年预算员工作总结
2014/12/05 职场文书
学前班学生评语
2014/12/29 职场文书
战马观后感
2015/06/08 职场文书
利用JuiceFS使MySQL 备份验证性能提升 10 倍
2022/03/17 MySQL
Minikube搭建Kubernetes集群
2022/03/31 Servers
Pandas实现批量拆分与合并Excel的示例代码
2022/05/30 Python