PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现备份文件实例
Sep 16 Python
实例解析Python中的__new__特殊方法
Jun 02 Python
Python利用matplotlib生成图片背景及图例透明的效果
Apr 27 Python
python实现自动登录
Sep 17 Python
pyqt5与matplotlib的完美结合实例
Jun 21 Python
Django框架视图层URL映射与反向解析实例分析
Jul 29 Python
Flask框架钩子函数功能与用法分析
Aug 02 Python
python实现飞行棋游戏
Feb 05 Python
基于Django OneToOneField和ForeignKey的区别详解
Mar 30 Python
python求前n个阶乘的和实例
Apr 02 Python
Python迭代器协议及for循环工作机制详解
Jul 14 Python
python zip()函数的使用示例
Sep 23 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
yii2整合百度编辑器umeditor及umeditor图片上传问题的解决办法
2016/04/20 PHP
Laravel5.5以下版本中如何自定义日志行为详解
2018/08/01 PHP
分析Node.js connect ECONNREFUSED错误
2013/04/09 Javascript
jquery.form.js用法之清空form的方法
2014/03/07 Javascript
Jquery选择器中使用变量实现动态选择例子
2014/07/25 Javascript
JS简单实现城市二级联动选择插件的方法
2015/08/19 Javascript
jquery实现树形菜单完整代码
2015/12/29 Javascript
JavaScript实现复制或剪切内容到剪贴板功能的方法
2016/05/23 Javascript
jQuery插件echarts设置折线图中折线线条颜色和折线点颜色的方法
2017/03/03 Javascript
简单快速的实现js计算器功能
2017/08/17 Javascript
使用vue2.6实现抖音【时间轮盘】屏保效果附源码
2019/04/24 Javascript
vue+element项目中过滤输入框特殊字符小结
2019/08/07 Javascript
JS实现时间校验的代码
2020/05/25 Javascript
python实现从web抓取文档的方法
2014/09/26 Python
Python对文件和目录进行操作的方法(file对象/os/os.path/shutil 模块)
2017/05/08 Python
Python实现发送QQ邮件的封装
2017/07/14 Python
python中正则表达式的使用方法
2018/02/25 Python
python 分离文件名和路径以及分离文件名和后缀的方法
2018/10/21 Python
Python的log日志功能及设置方法
2019/07/11 Python
关于阿里云oss获取sts凭证 app直传 python的实例
2019/08/20 Python
用python画一只可爱的皮卡丘实例
2019/11/21 Python
jupyter lab文件导出/下载方式
2020/04/22 Python
sklearn线性逻辑回归和非线性逻辑回归的实现
2020/06/09 Python
Django启动时找不到mysqlclient问题解决方案
2020/11/11 Python
python中os.remove()用法及注意事项
2021/01/31 Python
使用HTML5 Canvas为图片填充颜色和纹理的教程
2016/03/21 HTML / CSS
斯巴达比赛商店:Spartan Race
2019/01/08 全球购物
世界排名第一的运动鞋市场:Flight Club
2020/01/03 全球购物
sealed修饰符是干什么的
2012/10/23 面试题
思想汇报格式
2014/01/05 职场文书
大学生村官典型材料
2014/01/12 职场文书
爱国演讲稿400字
2014/05/07 职场文书
没有孩子的离婚协议书怎么写
2014/09/17 职场文书
六种css3实现的边框过渡效果
2021/04/22 HTML / CSS
仅仅使用 HTML/CSS 实现各类进度条的方式汇总
2021/11/11 HTML / CSS
PC版《死亡搁浅导剪版》现已发售 展开全新的探险
2022/04/03 其他游戏