PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于python的汉字转GBK码实现代码
Feb 19 Python
tornado框架blog模块分析与使用
Nov 21 Python
pip matplotlib报错equired packages can not be built解决
Jan 06 Python
Python使用try except处理程序异常的三种常用方法分析
Sep 05 Python
Python实现FM算法解析
Jun 18 Python
基于Python+Appium实现京东双十一自动领金币功能
Oct 31 Python
Python Pandas 对列/行进行选择,增加,删除操作
May 17 Python
django models里数据表插入数据id自增操作
Jul 15 Python
PyCharm2019 安装和配置教程详解附激活码
Jul 31 Python
Django3中的自定义用户模型实例详解
Aug 23 Python
python绘图pyecharts+pandas的使用详解
Dec 13 Python
用python开发一款操作MySQL的小工具
May 12 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
php获取远程图片体积大小的实例
2013/11/12 PHP
php数组去重复数据示例
2014/02/25 PHP
Parse正式发布开源PHP SDK
2014/08/11 PHP
PHP7多线程搭建教程
2017/04/21 PHP
Laravel使用RabbitMQ的方法示例
2019/06/18 PHP
兼容FireFox 的 js 日历 支持时间的获取
2009/03/04 Javascript
Javascript 键盘keyCode键码值表
2009/12/24 Javascript
jQuery EasyUI API 中文文档 - Calendar日历使用
2011/10/19 Javascript
基于jquery的bankInput银行卡账号格式化
2012/08/22 Javascript
jQuery多项选项卡的实现思路附样式及代码
2014/06/03 Javascript
html文本框提示效果的示例代码
2014/06/28 Javascript
JavaScript实现信用卡校验方法
2015/04/07 Javascript
js操作数组函数实例小结
2015/12/10 Javascript
JS定义类的六种方式详解
2016/05/12 Javascript
使用jquery datatable和bootsrap创建表格实例代码
2017/03/17 Javascript
浅谈js中的this问题
2017/08/31 Javascript
nodejs项目windows下开机自启动的方法
2017/11/22 NodeJs
微信小程序scroll-view隐藏滚动条的方法详解
2020/03/25 Javascript
Python使用pylab库实现绘制直方图功能示例
2018/06/01 Python
python使用epoll实现服务端的方法
2018/10/16 Python
Python collections中的双向队列deque简单介绍详解
2019/11/04 Python
Python pip使用超时问题解决方案
2020/08/03 Python
python实现不同数据库间数据同步功能
2021/02/25 Python
英国123鲜花网站:123 Flowers
2019/07/07 全球购物
工程测量与监理专业应届生求职信
2013/11/27 职场文书
财政专业求职信范文
2014/02/19 职场文书
《木笛》教学反思
2014/03/01 职场文书
开业庆典主持词
2014/03/21 职场文书
慈善晚会策划方案
2014/05/14 职场文书
计算机应用专业毕业生求职信
2014/06/03 职场文书
银行业务授权委托书
2014/10/10 职场文书
2015年房产销售工作总结范文
2015/05/22 职场文书
宝宝满月宴答谢词
2015/09/30 职场文书
数据结构课程设计心得体会
2016/01/15 职场文书
优秀乡村医生事迹材料(2016精选版)
2016/02/29 职场文书
mongodb数据库迁移变更的解决方案
2021/09/04 MongoDB