PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中zfill()方法的使用教程
May 20 Python
Python数据操作方法封装类实例
Jun 23 Python
Python set常用操作函数集锦
Nov 15 Python
VSCode下配置python调试运行环境的方法
Apr 06 Python
Python实现定时精度可调节的定时器
Apr 15 Python
Python实现html转换为pdf报告(生成pdf报告)功能示例
May 04 Python
python画蝴蝶曲线图的实例
Nov 21 Python
windows环境中利用celery实现简单任务队列过程解析
Nov 29 Python
python实现飞机大战游戏(pygame版)
Oct 26 Python
python代码实现TSNE降维数据可视化教程
Feb 28 Python
Python实现画图软件功能方法详解
Jul 28 Python
使用Djongo模块在Django中使用MongoDB数据库
Jun 20 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
一个简单计数器的源代码
2006/10/09 PHP
php 禁止页面缓存输出
2009/01/07 PHP
php高级编程-函数-郑阿奇
2011/07/04 PHP
THINKPHP项目开发中的日志记录实例分析
2014/12/01 PHP
php实现指定字符串中查找子字符串的方法
2015/03/17 PHP
PHP lcfirst()函数定义与用法
2019/03/08 PHP
input、button的不同type值在ajax提交表单时导致的陷阱
2009/02/24 Javascript
javascript 禁用IE工具栏,导航栏等等实现代码
2013/04/01 Javascript
jsp+javascript打造级连菜单的实例代码
2013/06/14 Javascript
javascript删除数组元素并且数组长度减小的简单实例
2014/02/14 Javascript
介绍JavaScript的一个微型模版
2015/06/24 Javascript
全系IE支持Bootstrap的解决方法
2015/10/19 Javascript
js显示当前日期时间和星期几
2015/10/22 Javascript
JS仿京东移动端手指拨动切换轮播图效果
2020/04/10 Javascript
Bootstrap 下拉多选框插件Bootstrap Multiselect
2017/01/22 Javascript
js求数组中全部数字可拼接出的最大整数示例代码
2017/08/25 Javascript
react-native DatePicker日期选择组件的实现代码
2017/09/12 Javascript
微信小程序之swiper轮播图中的图片自适应高度的方法
2018/04/23 Javascript
JS加密插件CryptoJS实现AES加密操作示例
2018/08/16 Javascript
详解react-refetch的使用小例子
2019/02/15 Javascript
[53:52]EG vs VGJ.T 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
linux系统使用python获取内存使用信息脚本分享
2014/01/15 Python
Python中functools模块的常用函数解析
2016/06/30 Python
python中快速进行多个字符替换的方法小结
2016/12/15 Python
Python2/3中urllib库的一些常见用法
2017/12/19 Python
PyQt5每天必学之工具提示功能
2018/04/19 Python
Python异常的检测和处理方法
2018/10/26 Python
python绘制双Y轴折线图以及单Y轴双变量柱状图的实例
2019/07/08 Python
利用Python代码实现一键抠背景功能
2019/12/29 Python
西班牙在线宠物商店:zooplus.es
2017/02/24 全球购物
Merrell迈乐澳大利亚网站:购买户外登山鞋
2017/05/28 全球购物
安全生产标语
2014/06/06 职场文书
公安机关纪律作风整顿剖析
2014/10/10 职场文书
2014年社区卫生工作总结
2014/12/18 职场文书
怎样写工作总结啊!
2019/06/18 职场文书
Redis入门基础常用操作命令整理
2022/06/01 Redis