PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之动态类型详解
Aug 30 Python
Python实现统计英文单词个数及字符串分割代码
May 28 Python
Python快速从注释生成文档的方法
Dec 26 Python
Python模拟三级菜单效果
Sep 11 Python
Python使用sklearn库实现的各种分类算法简单应用小结
Jul 04 Python
如何通过50行Python代码获取公众号全部文章
Jul 12 Python
python实现邮件自动发送
Aug 10 Python
pygame实现俄罗斯方块游戏(基础篇1)
Oct 29 Python
python3图片文件批量重命名处理
Oct 31 Python
解决pycharm不能自动补全第三方库的函数和属性问题
Mar 12 Python
在python下实现word2vec词向量训练与加载实例
Jun 09 Python
安装pyecharts1.8.0版本后导入pyecharts模块绘图时报错: “所有图表类型将在 v1.9.0 版本开始强制使用 ChartItem 进行数据项配置 ”的解决方法
Aug 18 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
Uncaught exception com_exception with message Failed to create COM object
2012/01/11 PHP
php-cli简介(不会Shell语言一样用Shell)
2013/06/03 PHP
php中利用str_pad函数生成数字递增形式的产品编号
2013/09/30 PHP
详解php反序列化
2020/06/10 PHP
关于递归运算的顺序测试代码
2011/11/30 Javascript
逐一介绍Jquery data()、Jquery stop()、jquery delay()函数(详)
2015/11/04 Javascript
使用jQuery5分钟快速搞定双色表格的简单实例
2016/08/08 Javascript
jQuery实现点击查看大图并以弹框的形式居中
2016/08/08 Javascript
全面解析标签页的切换方式
2016/08/21 Javascript
KnockoutJS 3.X API 第四章之事件event绑定
2016/10/10 Javascript
easyui datagrid 大数据加载效率慢,优化解决方法(推荐)
2016/11/09 Javascript
jQuery实现radio第一次点击选中第二次点击取消功能
2017/05/15 jQuery
React通过父组件传递类名给子组件的实现方法
2017/11/13 Javascript
Vue.js 2.0和Cordova开发webApp环境搭建方法
2018/02/26 Javascript
Angular4 ElementRef的应用
2018/02/26 Javascript
Vue项目webpack打包部署到Tomcat刷新报404错误问题的解决方案
2018/05/15 Javascript
Vue插件打包与发布的方法示例
2018/08/20 Javascript
js实现鼠标点击页面弹出自定义文字效果
2019/12/24 Javascript
Python中__call__用法实例
2014/08/29 Python
Google开源的Python格式化工具YAPF的安装和使用教程
2016/05/31 Python
python Pandas如何对数据集随机抽样
2019/07/29 Python
使用Python的Turtle库绘制森林的实例
2019/12/18 Python
Pytorch 保存模型生成图片方式
2020/01/10 Python
10个python3常用排序算法详细说明与实例(快速排序,冒泡排序,桶排序,基数排序,堆排序,希尔排序,归并排序,计数排序)
2020/03/17 Python
Python word文本自动化操作实现方法解析
2020/11/05 Python
Python基于execjs运行js过程解析
2020/11/27 Python
东方电视购物:东方CJ
2016/10/12 全球购物
如果NULL和0作为空指针常数是等价的,那我到底该用哪一个
2014/09/16 面试题
银行职业规划书范文
2013/12/28 职场文书
驻村工作先进事迹
2014/08/14 职场文书
2014年减负工作总结
2014/12/10 职场文书
行政文员岗位职责
2015/02/04 职场文书
自书遗嘱范文
2015/08/07 职场文书
nginx内存池源码解析
2021/11/20 Servers
使用Java去实现超市会员管理系统
2022/03/18 Java/Android
zabbix如何添加监控主机和自定义监控项
2022/08/14 Servers