PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之永远强大的函数
Sep 14 Python
Python实现的Excel文件读写类
Jul 30 Python
python定时关机小脚本
Jun 20 Python
Python爬虫框架Scrapy基本用法入门教程
Jul 26 Python
Python编程中flask的简介与简单使用
Dec 28 Python
python pygame实现五子棋小游戏
Oct 26 Python
python Django中models进行模糊查询的示例
Jul 18 Python
Python Scrapy图片爬取原理及代码实例
Jun 12 Python
Python配置pip国内镜像源的实现
Aug 20 Python
Python利用imshow制作自定义渐变填充柱状图(colorbar)
Dec 10 Python
Pycharm在指定目录下生成文件和删除文件的实现
Dec 28 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
php创建基本身份认证站点的方法详解
2013/06/08 PHP
PHP json_encode中文乱码问题的解决办法
2013/09/09 PHP
PHP两种快速排序算法实例
2015/02/15 PHP
9个比较实用的php代码片段
2016/03/15 PHP
php同时使用session和cookie来保存用户登录信息的实现代码
2016/05/13 PHP
yii2中dropDownList实现二级和三级联动写法
2017/04/26 PHP
JavaScript 浏览器验证代码(来自discuz)
2010/07/17 Javascript
javascript使用window.open提示“已经计划系统关机”的原因
2014/08/15 Javascript
jQuery实现带分组数据的Table表头排序实例分析
2015/11/24 Javascript
JavaScript编写检测用户所使用的浏览器的代码示例
2016/05/05 Javascript
js实现楼层效果的简单实例
2016/07/15 Javascript
JS获取input[file]的值并显示在页面的实现方法
2018/03/09 Javascript
nodejs简单读写excel内容的方法示例
2018/03/16 NodeJs
Vue SPA单页应用首屏优化实践
2018/06/28 Javascript
如何使用electron-builder及electron-updater给项目配置自动更新
2018/12/24 Javascript
使用 webpack 插件自动生成 vue 路由文件的方法
2019/08/20 Javascript
webpack 动态批量加载文件的实现方法
2020/03/19 Javascript
文章或博客自动生成章节目录索引(支持三级)的实现代码
2020/05/10 Javascript
[01:20:05]DOTA2-DPC中国联赛 正赛 Ehome vs VG BO3 第二场 2月5日
2021/03/11 DOTA
python获取当前时间对应unix时间戳的方法
2015/05/15 Python
深入理解Python装饰器
2016/07/27 Python
Python中工作日类库Busines Holiday的介绍与使用
2017/07/06 Python
在Pytorch中计算卷积方法的区别详解(conv2d的区别)
2020/01/03 Python
python利用faker库批量生成测试数据
2020/10/15 Python
浅谈Python __init__.py的作用
2020/10/28 Python
css3 给背景设置渐变色的方法
2019/09/12 HTML / CSS
使用Html5 Stream开发实时监控系统
2020/06/02 HTML / CSS
戴森比利时官方网站:Dyson BE
2020/10/03 全球购物
审计工作个人的自我评价
2013/12/25 职场文书
捐款倡议书怎么写
2014/05/13 职场文书
我爱我校演讲稿
2014/05/21 职场文书
四风自我剖析材料
2014/09/30 职场文书
离婚协议书范本及离婚须知
2014/10/15 职场文书
2014办公室年度工作总结
2014/12/09 职场文书
Nest.js参数校验和自定义返回数据格式详解
2021/03/29 Javascript
Python读写yaml文件
2022/03/20 Python