PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用cookielib库示例分享
Mar 03 Python
使用django-suit为django 1.7 admin后台添加模板
Nov 18 Python
python算法演练_One Rule 算法(详解)
May 17 Python
python flask实现分页的示例代码
Aug 02 Python
小白入门篇使用Python搭建点击率预估模型
Oct 12 Python
python爬虫实现中英翻译词典
Jun 25 Python
python实现的汉诺塔算法示例
Oct 23 Python
浅谈django 模型类使用save()方法的好处与注意事项
Mar 28 Python
Python基于模块Paramiko实现SSHv2协议
Apr 28 Python
Python把图片转化为pdf代码实例
Jul 28 Python
python中绕过反爬虫的方法总结
Nov 25 Python
Python爬虫实战案例之爬取喜马拉雅音频数据详解
Dec 07 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
PHP中魔术变量__METHOD__与__FUNCTION__的区别
2014/09/29 PHP
laravel 5.4中实现无限级分类的方法示例
2017/07/27 PHP
php语言注释,单行注释和多行注释
2018/01/21 PHP
PHP如何使用cURL实现Get和Post请求
2020/07/11 PHP
javascript语句中的CDATA标签的意义
2007/05/09 Javascript
js左侧多级菜单动态的解决方案
2010/02/01 Javascript
jquery插件冲突(jquery.noconflict)解决方法分享
2014/03/20 Javascript
JS继承用法实例分析
2015/02/05 Javascript
详解javascript中的事件处理
2015/11/06 Javascript
原生js配合cookie制作保存路径的拖拽
2015/12/29 Javascript
值得分享和收藏的Bootstrap学习教程
2016/05/12 Javascript
javascript中使用未定义变量或值的情况分析
2016/07/19 Javascript
Js 获取、判断浏览器版本信息的简单方法
2016/08/08 Javascript
在原生不支持的旧环境中添加兼容的Object.keys实现方法
2017/09/11 Javascript
Node.js中,在cmd界面,进入退出Node.js运行环境的方法
2018/05/12 Javascript
怎样在vue项目下添加ESLint的方法
2019/05/16 Javascript
vue+element 模态框表格形式的可编辑表单实现
2019/06/07 Javascript
vue+elementui 对话框取消 表单验证重置示例
2019/10/29 Javascript
vue 实现tab切换保持数据状态
2020/07/21 Javascript
Python的Django中django-userena组件的简单使用教程
2015/05/30 Python
Python pass详细介绍及实例代码
2016/11/24 Python
Python生成随机数组的方法小结
2017/04/15 Python
深入理解Django中内置的用户认证
2017/10/06 Python
单利模式及python实现方式详解
2018/03/20 Python
python中的decorator的作用详解
2018/07/26 Python
在unittest中使用 logging 模块记录测试数据的方法
2018/11/30 Python
VSCode基础使用与VSCode调试python程序入门的图文教程
2020/03/30 Python
Python是怎样处理json模块的
2020/07/16 Python
如何使用pycharm连接Databricks的步骤详解
2020/09/23 Python
Html5 语法与规则简要概述
2014/07/29 HTML / CSS
简历中个人自我评价分享
2014/03/15 职场文书
抗震救灾标语
2014/06/26 职场文书
清明节扫墓活动总结
2015/02/09 职场文书
vue响应式原理与双向数据的深入解析
2021/06/04 Vue.js
Windows 11要来了?微软文档揭示Win11太阳谷 / Win10有两个不同版本
2021/11/21 数码科技
CSS使用伪类控制边框长度的方法
2022/01/18 HTML / CSS