PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
有关wxpython pyqt内存占用问题分析
Jun 09 Python
python提示No module named images的解决方法
Sep 29 Python
Python的Flask框架中实现简单的登录功能的教程
Apr 20 Python
在Python中使用列表生成式的教程
Apr 27 Python
Python中关于使用模块的基础知识
May 24 Python
在Django的视图中使用form对象的方法
Jul 18 Python
Python中你应该知道的一些内置函数
Mar 31 Python
Python的argparse库使用详解
Oct 09 Python
感知器基础原理及python实现过程详解
Sep 30 Python
Python爬虫爬取电影票房数据及图表展示操作示例
Mar 27 Python
Python函数参数定义及传递方式解析
Jun 10 Python
详解Anaconda 的安装教程
Sep 23 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
如何从一个php文件向另一个地址post数据,不用表单和隐藏的变量的
2007/03/06 PHP
ThinkPHP调用百度翻译类实现在线翻译
2014/06/26 PHP
Ubuntu VPS中wordpress网站打开时提示”建立数据库连接错误”的解决办法
2016/11/03 PHP
PHP实现无限分类的实现方法
2016/11/14 PHP
在云虚拟主机部署thinkphp5项目的步骤详解
2017/12/21 PHP
关于Laravel参数验证的一些疑与惑
2019/11/19 PHP
基于pthread_create,readlink,getpid等函数的学习与总结
2013/07/17 Javascript
使用cluster 将自己的Node服务器扩展为多线程服务器
2014/11/10 Javascript
JS是按值传递还是按引用传递
2015/01/30 Javascript
浅析jQuery Mobile的初始化事件
2015/12/03 Javascript
js不间断滚动的简单实现
2016/06/03 Javascript
BootStrap智能表单实战系列(六)表单编辑页面的数据绑定
2016/06/13 Javascript
JSONP和批量操作功能的实现方法
2016/08/21 Javascript
JS动态遍历json中所有键值对的方法(不知道属性名的情况)
2016/12/28 Javascript
用npm安装vue和vue-cli,并使用webpack创建项目的方法
2018/09/28 Javascript
Electron-vue脚手架改造vue项目的方法
2018/10/22 Javascript
微信小程序实现左滑修改、删除功能
2020/10/19 Javascript
深入解析Vue源码实例挂载与编译流程实现思路详解
2019/05/05 Javascript
用Vue.js方法创建模板并使用多个模板合成
2019/06/28 Javascript
python正常时间和unix时间戳相互转换的方法
2015/04/23 Python
python urllib urlopen()对象方法/代理的补充说明
2017/06/29 Python
python tensorflow学习之识别单张图片的实现的示例
2018/02/09 Python
python实现excel读写数据
2021/03/02 Python
解决Python列表字符不区分大小写的问题
2019/12/19 Python
python进行参数传递的方法
2020/05/12 Python
python如何绘制疫情图
2020/09/16 Python
使用django自带的user做外键的方法
2020/11/30 Python
预订旅游活动、景点和旅游:GetYourGuide
2019/09/29 全球购物
金属材料工程个人求职的自我评价
2013/12/04 职场文书
办公室秘书自我鉴定
2014/01/18 职场文书
旅游管理专业大学生职业规划书
2014/02/27 职场文书
追讨欠款律师函
2015/06/24 职场文书
2015年村级财务管理制度
2015/08/04 职场文书
CSS3通过var()和calc()函数实现动画特效
2021/03/30 HTML / CSS
日元符号 ¥
2022/02/17 杂记
Vue 打包后相对路径的引用问题
2022/06/05 Vue.js