PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
以Python的Pyspider为例剖析搜索引擎的网络爬虫实现方法
Mar 30 Python
Python中decorator使用实例
Apr 14 Python
Python编程中的异常处理教程
Aug 21 Python
Python实现的选择排序算法示例
Nov 29 Python
Python使用Tkinter实现机器人走迷宫
Jan 22 Python
Django中Forms的使用代码解析
Feb 10 Python
numpy按列连接两个维数不同的数组方式
Dec 06 Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 Python
Pytorch之finetune使用详解
Jan 18 Python
pandas的resample重采样的使用
Apr 24 Python
Django-simple-captcha验证码包使用方法详解
Nov 28 Python
Python字典和列表性能之间的比较
Jun 07 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
BBS(php & mysql)完整版(二)
2006/10/09 PHP
php radio 单选框获取与保持值的实现代码
2010/05/15 PHP
php使用PDO获取结果集的方法
2017/02/16 PHP
Laravel find in set排序实例
2019/10/09 PHP
Docker 安装 PHP并与Nginx的部署实例讲解
2021/02/27 PHP
jQuery 源码分析笔记(7) Queue
2011/06/19 Javascript
jquery $.each() 使用小探
2013/08/23 Javascript
Jquery实现的角色左右选择特效
2014/05/21 Javascript
Javascript控制input输入时间格式的方法
2015/01/28 Javascript
jquery easyui datagrid实现增加,修改,删除方法总结
2016/05/25 Javascript
AngularJS 自定义指令详解及示例代码
2016/08/17 Javascript
网页瀑布流布局jQuery实现代码
2016/10/21 Javascript
关于webuploader插件使用过程遇到的小问题
2016/11/07 Javascript
详解Sea.js中Module.exports和exports的区别
2017/02/12 Javascript
p5.js实现斐波那契螺旋的示例代码
2018/03/22 Javascript
js与jQuery实现获取table中的数据并拼成json字符串操作示例
2018/07/12 jQuery
Layui实现带查询条件的分页
2019/07/27 Javascript
JavaScript 实现下雪特效的示例代码
2020/09/09 Javascript
python利用Guetzli批量压缩图片
2017/03/23 Python
Python使用matplotlib的pie函数绘制饼状图功能示例
2018/01/08 Python
想学python 这5本书籍你必看!
2018/12/11 Python
Python如何筛选序列中的元素的方法实现
2019/07/15 Python
django 简单实现登录验证给你
2019/11/06 Python
Python3基本输入与输出操作实例分析
2020/02/14 Python
Python configparser模块操作代码实例
2020/06/08 Python
pycharm 使用tab跳出正在编辑的括号(){}{}等问题
2021/02/26 Python
一款html5 canvas实现的图片玻璃碎片特效
2014/09/11 HTML / CSS
阿迪达斯比利时官方商城:adidas比利时
2016/10/10 全球购物
法国大使拉杆箱官网:DELSEY Paris
2018/03/20 全球购物
国外的一些J2EE面试题一
2012/10/13 面试题
行政部工作岗位职责范本
2014/03/05 职场文书
小学作文评语大全
2014/04/21 职场文书
护士业务学习心得体会
2016/01/25 职场文书
python实现图片九宫格分割的示例
2021/04/25 Python
如何制作自己的原生JavaScript路由
2021/05/05 Javascript
实战 快速定位MySQL的慢SQL
2022/03/22 MySQL