PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用next()方法操作文件的教程
May 24 Python
Python的Twisted框架中使用Deferred对象来管理回调函数
May 25 Python
Python合并多个Excel数据的方法
Jul 16 Python
python随机在一张图像上截取任意大小图片的方法
Jan 24 Python
如何在Python中实现goto语句的方法
May 18 Python
PYTHON绘制雷达图代码实例
Oct 15 Python
win10环境下配置vscode python开发环境的教程详解
Oct 16 Python
python计算二维矩形IOU实例
Jan 18 Python
matplotlib实现数据实时刷新的示例代码
Jan 05 Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 Python
详解Python小数据池和代码块缓存机制
Apr 07 Python
python numpy中multiply与*及matul 的区别说明
May 26 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
PHP 和 XML: 使用expat函数(一)
2006/10/09 PHP
thinkphp实现163、QQ邮箱收发邮件的方法
2015/12/18 PHP
PHP内存缓存功能memcached示例
2016/10/19 PHP
用javascript动态调整iframe高度的方法
2007/03/06 Javascript
javascript 字符串连接的性能问题(多浏览器)
2008/11/18 Javascript
jquery表单验证插件(jquery.validate.js)的3种使用方式
2015/03/28 Javascript
javascript实现table选中的行以指定颜色高亮显示的方法
2015/05/13 Javascript
jQuery实现默认是闭合的FAQ展开效果菜单
2015/09/14 Javascript
浅谈addEventListener和attachEvent的区别
2016/07/14 Javascript
AngularJS 依赖注入详解及示例代码
2016/08/17 Javascript
Vuejs入门教程之Vue生命周期,数据,手动挂载,指令,过滤器
2017/04/19 Javascript
vue加载自定义的js文件方法
2018/03/13 Javascript
JS验证输入的是否是数字及保留几位小数问题
2018/05/09 Javascript
nodejs中express入门和基础知识点学习
2018/09/13 NodeJs
使用 electron 实现类似新版 QQ 的登录界面效果(阴影、背景动画、窗体3D翻转)
2018/10/23 Javascript
Python中AND、OR的一个使用小技巧
2015/02/18 Python
python绘制双柱形图代码实例
2017/12/14 Python
浅谈Django自定义模板标签template_tags的用处
2017/12/20 Python
flask框架使用orm连接数据库的方法示例
2018/07/16 Python
python被修饰的函数消失问题解决(基于wraps函数)
2019/11/04 Python
tensorflow模型转ncnn的操作方式
2020/05/25 Python
python函数map()和partial()的知识点总结
2020/05/26 Python
基于python实现操作git过程代码解析
2020/07/27 Python
Python图像识别+KNN求解数独的实现
2020/11/13 Python
python元组拆包实现方法
2021/02/28 Python
美国在线奢侈品寄售商店:Luxury Garage Sale
2018/08/19 全球购物
杭州-飞时达软件有限公司.net笔面试
2012/04/28 面试题
自荐信范文
2013/12/10 职场文书
优质服务活动实施方案
2014/05/02 职场文书
中层干部培训方案
2014/06/16 职场文书
焦裕禄精神心得体会
2014/09/02 职场文书
试用期员工工作自我评价
2014/09/10 职场文书
建筑质检员岗位职责
2015/04/08 职场文书
教师文明餐桌光盘行动倡议书
2015/04/28 职场文书
中学教师教学工作总结
2015/08/13 职场文书
在vue中import()语法不能传入变量的问题及解决
2022/04/01 Vue.js