PyTorch的深度学习入门教程之构建神经网络


Posted in Python onJune 27, 2019

前言

本文参考PyTorch官网的教程,分为五个基本模块来介绍PyTorch。为了避免文章过长,这五个模块分别在五篇博文中介绍。

Part3:使用PyTorch构建一个神经网络

神经网络可以使用touch.nn来构建。nn依赖于autograd来定义模型,并且对其求导。一个nn.Module包含网络的层(layers),同时forward(input)可以返回output。

这是一个简单的前馈网络。它接受输入,然后一层一层向前传播,最后输出一个结果。

训练神经网络的典型步骤如下:

(1)  定义神经网络,该网络包含一些可以学习的参数(如权重)

(2)  在输入数据集上进行迭代

(3)  使用网络对输入数据进行处理

(4)  计算loss(输出值距离正确值有多远)

(5)  将梯度反向传播到网络参数中

(6)  更新网络的权重,使用简单的更新法则:weight = weight - learning_rate* gradient,即:新的权重=旧的权重-学习率*梯度值。

1 定义网络

我们先定义一个网络:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

  def __init__(self):
    super(Net, self).__init__()
    # 1 input image channel, 6 output channels, 5x5 square convolution
    # kernel
    self.conv1 = nn.Conv2d(1, 6, 5)
    self.conv2 = nn.Conv2d(6, 16, 5)
    # an affine operation: y = Wx + b
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # Max pooling over a (2, 2) window
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # If the size is a square you can only specify a single number
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)
    x = x.view(-1, self.num_flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

  def num_flat_features(self, x):
    size = x.size()[1:] # all dimensions except the batch dimension
    num_features = 1
    for s in size:
      num_features *= s
    return num_features


net = Net()
print(net)

预期输出:

Net (

 (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))

 (fc1): Linear (400 ->120)

 (fc2): Linear (120 ->84)

 (fc3): Linear (84 ->10)

)

你只需要定义forward函数,那么backward函数(梯度在此函数中计算)就会利用autograd来自动定义。你可以在forward函数中使用Tensor的任何运算。

学习到的参数可以被net.parameters()返回。

params = list(net.parameters())
print(len(params))
print(params[0].size()) # conv1's .weight

预期输出:

10

torch.Size([6, 1, 5, 5])

前向计算的输入和输出都是autograd.Variable,注意,这个网络(LeNet)的输入尺寸是32*32。为了在MNIST数据集上使用这个网络,请把图像大小转变为32*32。

input = Variable(torch.randn(1, 1, 32, 32))
out = net(input)
print(out)

预期输出:

Variable containing:
-0.0796 0.0330 0.0103 0.0250 0.1153 -0.0136 0.0234 0.0881 0.0374 -0.0359
[torch.FloatTensor of size 1x10]

将梯度缓冲区归零,然后使用随机梯度值进行反向传播。

net.zero_grad()
out.backward(torch.randn(1, 10))

注意:torch.nn只支持mini-batches. 完整的torch.nn package只支持mini-batch形式的样本作为输入,并且不能只包含一个样本。例如,nn.Conv2d会采用一个4D的Tensor(nSamples* nChannels * Height * Width)。如果你有一个单样本,可以使用input.unsqueeze(0)来添加一个虚假的批量维度。

在继续之前,让我们回顾一下迄今为止所见过的所有类。

概述:

(1)  torch.Tensor——多维数组

(2)  autograd.Variable——包装了一个Tensor,并且记录了应用于其上的运算。与Tensor具有相同的API,同时增加了一些新东西例如backward()。并且有相对于该tensor的梯度值。

(3)  nn.Module——神经网络模块。封装参数的简便方式,对于参数向GPU移动,以及导出、加载等有帮助。

(4)  nn.Parameter——这是一种变量(Variable),当作为一个属性(attribute)分配到一个模块(Module)时,可以自动注册为一个参数(parameter)。

(5)  autograd.Function——执行自动求导运算的前向和反向定义。每一个Variable运算,创建至少一个单独的Function节点,该节点连接到创建了Variable并且编码了它的历史的函数身上。

2 损失函数(Loss Function)

损失函数采用输出值和目标值作为输入参数,来计算输出值距离目标值还有多大差距。在nn package中有很多种不同的损失函数,最简单的一个loss就是nn.MSELoss,它计算输出值和目标值之间的均方差。

例如:

output = net(input)
target = Variable(torch.arange(1, 11)) # a dummy target, for example
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)

现在,从反向看loss,使用.grad_fn属性,你会看到一个计算graph如下:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
   -> view -> linear -> relu -> linear -> relu -> linear
   -> MSELoss
   -> loss

当我们调用loss.backward(),整个的graph关于loss求导,graph中的所有Variables都会有他们自己的.grad变量。

为了理解,我们进行几个反向步骤。

print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU

预期输出:

<torch.autograd.function.MSELossBackwardobjectat0x7fb3c0dcf4f8>

<torch.autograd.function.AddmmBackwardobjectat0x7fb3c0dcf408>

<AccumulateGradobjectat0x7fb3c0db79e8>

3 反向传播(Backprop)

可以使用loss.backward()进行误差反向传播。你需要清除已经存在的梯度值,否则梯度将会积累到现有的梯度上。

现在,我们调用loss.backward(),看一看conv1的bias 梯度在backward之前和之后的值。

net.zero_grad()   # zeroes the gradient buffers of all parameters

print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

4 更新权重

实践当中最简单的更新法则就是随机梯度下降法( StochasticGradient Descent (SGD))

weight = weight - learning_rate * gradient

执行这个操作的python代码如下:

learning_rate = 0.01
for f in net.parameters():
  f.data.sub_(f.grad.data * learning_rate)

但是当你使用神经网络的时候,你可能会想要尝试多种不同的更新法则,例如SGD,Nesterov-SGD, Adam, RMSProp等。为了实现此功能,有一个package叫做torch.optim已经实现了这些。使用它也很方便:

import torch.optim as optim

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
optimizer.zero_grad()  # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()  # Does the update

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
win7上python2.7连接mysql数据库的方法
Jan 14 Python
python 信息同时输出到控制台与文件的实例讲解
May 11 Python
python实现批量图片格式转换
Jun 16 Python
Python爬取商家联系电话以及各种数据的方法
Nov 10 Python
Python数据报表之Excel操作模块用法分析
Mar 11 Python
详解小白之KMP算法及python实现
Apr 04 Python
Python获取统计自己的qq群成员信息的方法
Nov 15 Python
Pytorch: 自定义网络层实例
Jan 07 Python
关于python pycharm中输出的内容不全的解决办法
Jan 10 Python
Python3爬虫中Ajax的用法
Jul 10 Python
学生如何注册Pycharm专业版以及pycharm的安装
Sep 24 Python
python设置 matplotlib 正确显示中文的四种方式
May 10 Python
PyTorch的深度学习入门之PyTorch安装和配置
Jun 27 #Python
解决pycharm remote deployment 配置的问题
Jun 27 #Python
python turtle库画一个方格和圆实例
Jun 27 #Python
Python实现的对一个数进行因式分解操作示例
Jun 27 #Python
python pytest进阶之xunit fixture详解
Jun 27 #Python
Python批量查询关键词微信指数实例方法
Jun 27 #Python
Django框架orM与自定义SQL语句混合事务控制操作
Jun 27 #Python
You might like
PHP缓存技术的使用说明
2011/08/06 PHP
PHP字符串中特殊符号的过滤方法介绍
2014/02/18 PHP
PHP程序员基本要求和必备技能
2014/05/09 PHP
php绘图之在图片上写中文和英文的方法
2015/01/24 PHP
php微信公众平台开发(一) 配置接口
2016/12/06 PHP
php实现的统计字数函数定义与使用示例
2017/07/26 PHP
PHP自定义递归函数实现数组转JSON功能【支持GBK编码】
2018/07/17 PHP
javascript 动态调整图片尺寸实现代码
2009/12/28 Javascript
JavaScript 学习笔记(十五)
2010/01/28 Javascript
利用jQuery的deferred对象实现异步按顺序加载JS文件
2013/03/17 Javascript
js获取url中&quot;?&quot;后面的字串方法
2014/05/15 Javascript
jQuery 自定义下拉框(DropDown)附源码下载
2016/07/22 Javascript
jQuery旋转插件jqueryrotate用法详解
2016/10/13 Javascript
Vue.directive自定义指令的使用详解
2017/03/10 Javascript
JS库中的Particles.js在vue上的运用案例分析
2017/09/13 Javascript
JS实现的简单表单验证功能完整实例
2017/10/14 Javascript
JavaScript设计模式之单例模式简单实例教程
2018/07/02 Javascript
vue 实现左右拖拽元素并且不超过他的父元素的宽度
2018/11/30 Javascript
Layui 数据表格批量删除和多条件搜索的实例
2019/09/04 Javascript
Vue中key的作用示例代码详解
2020/06/10 Javascript
[05:01]3.19DOTA2发布会 我们都是刀塔人
2014/03/25 DOTA
[01:28]2014DOTA2国际邀请赛中国区预选赛四大豪门直升机抵达会场
2014/05/24 DOTA
python中set常用操作汇总
2016/06/30 Python
解决Python 命令行执行脚本时,提示导入的包找不到的问题
2019/01/19 Python
利用OpenCV和Python实现查找图片差异
2019/12/19 Python
Python字典深浅拷贝与循环方式方法详解
2020/02/09 Python
完美解决ARIMA模型中plot_acf画不出图的问题
2020/06/04 Python
意大利火车票和铁路通行证专家:ItaliaRail
2019/01/22 全球购物
意大利运动服减价商店:ScontoSport
2020/03/10 全球购物
武汉瑞得软件笔试题
2015/10/27 面试题
个人评价范文分享
2014/01/11 职场文书
法学专业毕业生求职信
2014/06/12 职场文书
继承权公证书范本
2015/01/23 职场文书
护士心得体会范文
2016/01/25 职场文书
JavaScript实现两个数组的交集
2022/03/25 Javascript
Python使用华为API为图像设置多个锚点标签
2022/04/12 Python