PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python压缩和解压缩zip文件
Feb 14 Python
给Python入门者的一些编程建议
Jun 15 Python
Python实现的密码强度检测器示例
Aug 23 Python
Python针对给定列表中元素进行翻转操作的方法分析
Apr 27 Python
基于python requests库中的代理实例讲解
May 07 Python
Python二进制串转换为通用字符串的方法
Jul 23 Python
tensorflow学习教程之文本分类详析
Aug 07 Python
Django组件之cookie与session的使用方法
Jan 10 Python
python从list列表中选出一个数和其对应的坐标方法
Jul 20 Python
python和pywin32实现窗口查找、遍历和点击的示例代码
Apr 01 Python
Django {{ MEDIA_URL }}无法显示图片的解决方式
Apr 07 Python
pycharm实现猜数游戏
Dec 07 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
用PHP去掉文件头的Unicode签名(BOM)方法
2017/06/22 PHP
phpfpm的作用和用法
2019/10/10 PHP
JS继承--原型链继承和类式继承
2013/04/08 Javascript
使用js检测浏览器的实现代码
2013/05/14 Javascript
jquery中页面Ajax方法$.load的功能使用介绍
2014/10/20 Javascript
Javascript中call和apply函数的比较和使用实例
2015/02/03 Javascript
JavaScript中Number.NEGATIVE_INFINITY值的使用详解
2015/06/05 Javascript
如何实现移动端浏览器不显示 pc 端的广告
2015/10/15 Javascript
谈一谈javascript闭包
2016/01/28 Javascript
浅析JS获取url中的参数实例代码
2016/06/14 Javascript
Kotlin学习第一步 kotlin语法特性
2017/05/25 Javascript
详解Require.js与Sea.js的区别
2018/08/05 Javascript
vue router 通过路由来实现切换头部标题功能
2019/04/24 Javascript
使用 js 简单的实现 bind、call 、aplly代码实例
2019/09/07 Javascript
[11:12]2018DOTA2国际邀请赛寻真——绿色长城OpTic
2018/08/10 DOTA
python 输出一个两行字符的变量
2009/02/05 Python
python 不关闭控制台的实现方法
2011/10/23 Python
python线程、进程和协程详解
2016/07/19 Python
python list排序的两种方法及实例讲解
2017/03/20 Python
Python中easy_install 和 pip 的安装及使用
2017/06/05 Python
python 阶乘累加和的实例
2019/02/01 Python
Python异常模块traceback用法实例分析
2019/10/22 Python
python GUI库图形界面开发之PyQt5单选按钮控件QRadioButton详细使用方法与实例
2020/02/28 Python
python网络编程:socketserver的基本使用方法实例分析
2020/04/09 Python
CSS3新属性transition-property transform box-shadow实例学习
2013/06/06 HTML / CSS
阿迪达斯芬兰官方网站:adidas芬兰
2017/01/30 全球购物
波兰最大的电商平台:Allegro.pl
2021/02/06 全球购物
创建索引时需要注意的事项
2013/05/13 面试题
夜大毕业生自我评价分享
2013/11/10 职场文书
酒店中秋节活动方案
2014/01/31 职场文书
团组织推优材料
2014/12/29 职场文书
2015年庆祝国庆节66周年演讲稿
2015/07/30 职场文书
高中同学会致辞
2015/08/01 职场文书
nginx对http请求处理的各个阶段详析
2021/03/31 Servers
常用的MongoDB查询语句的示例代码
2021/07/25 MongoDB
mysql中DCL常用的用户和权限控制
2022/03/31 MySQL