PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之复习if语句
Oct 02 Python
详解Python当中的字符串和编码
Apr 25 Python
python中迭代器(iterator)用法实例分析
Apr 29 Python
Python的Flask站点中集成xhEditor文本编辑器的教程
Jun 13 Python
python+selenium实现京东自动登录及秒杀功能
Nov 18 Python
Python range、enumerate和zip函数用法详解
Sep 11 Python
pandas按行按列遍历Dataframe的几种方式
Oct 23 Python
Django配置文件代码说明
Dec 04 Python
Python3 Click模块的使用方法详解
Feb 12 Python
Django Serializer HiddenField隐藏字段实例
Mar 31 Python
python爬虫学习笔记之pyquery模块基本用法详解
Apr 09 Python
使用pipenv管理python虚拟环境的全过程
Sep 25 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
Optimizer与Debugger兼容性问题的解决方法
2008/12/01 PHP
PHP5中虚函数的实现方法分享
2011/04/20 PHP
php mssql扩展SQL查询中文字段名解决方法
2012/10/15 PHP
Linux中用PHP判断程序运行状态的2个方法
2014/05/04 PHP
destoon会员注册提示“数据校验失败(2)”解决方法
2014/06/21 PHP
自写的利用PDO对mysql数据库增删改查操作类
2018/02/19 PHP
js substr、substring和slice使用说明小记
2011/09/15 Javascript
jquery validate和jquery form 插件组合实现验证表单后AJAX提交
2015/08/26 Javascript
限制复选框最多选择项的实现代码
2016/05/30 Javascript
BootStrap 智能表单实战系列(十)自动完成组件的支持
2016/06/13 Javascript
javascript 常用验证函数总结
2016/06/28 Javascript
Bootstrap编写一个兼容主流浏览器的受众巨幕式风格页面
2016/07/01 Javascript
JavaScript数值千分位格式化的两种简单实现方法
2016/08/01 Javascript
js链表操作(实例讲解)
2017/08/29 Javascript
vue监听input标签的value值方法
2018/08/27 Javascript
ES6 十大特性简介
2020/12/09 Javascript
Python实现求最大公约数及判断素数的方法
2015/05/26 Python
Perl中著名的Schwartzian转换问题解决实现
2015/06/02 Python
基于Django用户认证系统详解
2018/02/21 Python
python基础教程项目四之新闻聚合
2018/04/02 Python
Django 使用logging打印日志的实例
2018/04/28 Python
python接口自动化(十七)--Json 数据处理---一次爬坑记(详解)
2019/04/18 Python
Python OpenCV中的resize()函数的使用
2019/06/20 Python
Python3环境安装Scrapy爬虫框架过程及常见错误
2019/07/12 Python
python3+django2开发一个简单的人员管理系统过程详解
2019/07/23 Python
Python获取统计自己的qq群成员信息的方法
2019/11/15 Python
python GUI库图形界面开发之PyQt5开发环境配置与基础使用
2020/02/25 Python
HTML5 Convas APIs方法详解
2015/04/24 HTML / CSS
中国最大的潮流商品购物网站:YOHO!BUY有货
2017/01/07 全球购物
教师的实习鉴定
2013/12/15 职场文书
八一建军节活动方案
2014/02/10 职场文书
3.15国际消费者权益日主题活动活动总结
2014/03/16 职场文书
小学综合实践活动总结
2014/07/07 职场文书
消费者投诉书范文
2015/07/02 职场文书
年会邀请函的格式及范文五篇
2019/11/02 职场文书
如何利用React实现图片识别App
2022/02/18 Javascript