PyTorch的深度学习入门教程之构建神经网络


Posted in Python onJune 27, 2019

前言

本文参考PyTorch官网的教程,分为五个基本模块来介绍PyTorch。为了避免文章过长,这五个模块分别在五篇博文中介绍。

Part3:使用PyTorch构建一个神经网络

神经网络可以使用touch.nn来构建。nn依赖于autograd来定义模型,并且对其求导。一个nn.Module包含网络的层(layers),同时forward(input)可以返回output。

这是一个简单的前馈网络。它接受输入,然后一层一层向前传播,最后输出一个结果。

训练神经网络的典型步骤如下:

(1)  定义神经网络,该网络包含一些可以学习的参数(如权重)

(2)  在输入数据集上进行迭代

(3)  使用网络对输入数据进行处理

(4)  计算loss(输出值距离正确值有多远)

(5)  将梯度反向传播到网络参数中

(6)  更新网络的权重,使用简单的更新法则:weight = weight - learning_rate* gradient,即:新的权重=旧的权重-学习率*梯度值。

1 定义网络

我们先定义一个网络:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

  def __init__(self):
    super(Net, self).__init__()
    # 1 input image channel, 6 output channels, 5x5 square convolution
    # kernel
    self.conv1 = nn.Conv2d(1, 6, 5)
    self.conv2 = nn.Conv2d(6, 16, 5)
    # an affine operation: y = Wx + b
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # Max pooling over a (2, 2) window
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # If the size is a square you can only specify a single number
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)
    x = x.view(-1, self.num_flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

  def num_flat_features(self, x):
    size = x.size()[1:] # all dimensions except the batch dimension
    num_features = 1
    for s in size:
      num_features *= s
    return num_features


net = Net()
print(net)

预期输出:

Net (

 (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))

 (fc1): Linear (400 ->120)

 (fc2): Linear (120 ->84)

 (fc3): Linear (84 ->10)

)

你只需要定义forward函数,那么backward函数(梯度在此函数中计算)就会利用autograd来自动定义。你可以在forward函数中使用Tensor的任何运算。

学习到的参数可以被net.parameters()返回。

params = list(net.parameters())
print(len(params))
print(params[0].size()) # conv1's .weight

预期输出:

10

torch.Size([6, 1, 5, 5])

前向计算的输入和输出都是autograd.Variable,注意,这个网络(LeNet)的输入尺寸是32*32。为了在MNIST数据集上使用这个网络,请把图像大小转变为32*32。

input = Variable(torch.randn(1, 1, 32, 32))
out = net(input)
print(out)

预期输出:

Variable containing:
-0.0796 0.0330 0.0103 0.0250 0.1153 -0.0136 0.0234 0.0881 0.0374 -0.0359
[torch.FloatTensor of size 1x10]

将梯度缓冲区归零,然后使用随机梯度值进行反向传播。

net.zero_grad()
out.backward(torch.randn(1, 10))

注意:torch.nn只支持mini-batches. 完整的torch.nn package只支持mini-batch形式的样本作为输入,并且不能只包含一个样本。例如,nn.Conv2d会采用一个4D的Tensor(nSamples* nChannels * Height * Width)。如果你有一个单样本,可以使用input.unsqueeze(0)来添加一个虚假的批量维度。

在继续之前,让我们回顾一下迄今为止所见过的所有类。

概述:

(1)  torch.Tensor——多维数组

(2)  autograd.Variable——包装了一个Tensor,并且记录了应用于其上的运算。与Tensor具有相同的API,同时增加了一些新东西例如backward()。并且有相对于该tensor的梯度值。

(3)  nn.Module——神经网络模块。封装参数的简便方式,对于参数向GPU移动,以及导出、加载等有帮助。

(4)  nn.Parameter——这是一种变量(Variable),当作为一个属性(attribute)分配到一个模块(Module)时,可以自动注册为一个参数(parameter)。

(5)  autograd.Function——执行自动求导运算的前向和反向定义。每一个Variable运算,创建至少一个单独的Function节点,该节点连接到创建了Variable并且编码了它的历史的函数身上。

2 损失函数(Loss Function)

损失函数采用输出值和目标值作为输入参数,来计算输出值距离目标值还有多大差距。在nn package中有很多种不同的损失函数,最简单的一个loss就是nn.MSELoss,它计算输出值和目标值之间的均方差。

例如:

output = net(input)
target = Variable(torch.arange(1, 11)) # a dummy target, for example
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)

现在,从反向看loss,使用.grad_fn属性,你会看到一个计算graph如下:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
   -> view -> linear -> relu -> linear -> relu -> linear
   -> MSELoss
   -> loss

当我们调用loss.backward(),整个的graph关于loss求导,graph中的所有Variables都会有他们自己的.grad变量。

为了理解,我们进行几个反向步骤。

print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU

预期输出:

<torch.autograd.function.MSELossBackwardobjectat0x7fb3c0dcf4f8>

<torch.autograd.function.AddmmBackwardobjectat0x7fb3c0dcf408>

<AccumulateGradobjectat0x7fb3c0db79e8>

3 反向传播(Backprop)

可以使用loss.backward()进行误差反向传播。你需要清除已经存在的梯度值,否则梯度将会积累到现有的梯度上。

现在,我们调用loss.backward(),看一看conv1的bias 梯度在backward之前和之后的值。

net.zero_grad()   # zeroes the gradient buffers of all parameters

print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

4 更新权重

实践当中最简单的更新法则就是随机梯度下降法( StochasticGradient Descent (SGD))

weight = weight - learning_rate * gradient

执行这个操作的python代码如下:

learning_rate = 0.01
for f in net.parameters():
  f.data.sub_(f.grad.data * learning_rate)

但是当你使用神经网络的时候,你可能会想要尝试多种不同的更新法则,例如SGD,Nesterov-SGD, Adam, RMSProp等。为了实现此功能,有一个package叫做torch.optim已经实现了这些。使用它也很方便:

import torch.optim as optim

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
optimizer.zero_grad()  # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()  # Does the update

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
go和python调用其它程序并得到程序输出
Feb 10 Python
python统计文本文件内单词数量的方法
May 30 Python
使用Python的Twisted框架编写非阻塞程序的代码示例
May 25 Python
Python实现的多线程http压力测试代码
Feb 08 Python
Python实现字符串格式化的方法小结
Feb 20 Python
python pandas 对series和dataframe的重置索引reindex方法
Jun 07 Python
pandas去重复行并分类汇总的实现方法
Jan 29 Python
Dlib+OpenCV深度学习人脸识别的方法示例
May 14 Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 Python
浅谈python输出列表元素的所有排列形式
Feb 26 Python
Prometheus开发中间件Exporter过程详解
Nov 30 Python
利用Python第三方库实现预测NBA比赛结果
Jun 21 Python
PyTorch的深度学习入门之PyTorch安装和配置
Jun 27 #Python
解决pycharm remote deployment 配置的问题
Jun 27 #Python
python turtle库画一个方格和圆实例
Jun 27 #Python
Python实现的对一个数进行因式分解操作示例
Jun 27 #Python
python pytest进阶之xunit fixture详解
Jun 27 #Python
Python批量查询关键词微信指数实例方法
Jun 27 #Python
Django框架orM与自定义SQL语句混合事务控制操作
Jun 27 #Python
You might like
ThinkPHP缓存方法S()概述
2014/06/13 PHP
php中cookie实现二级域名可访问操作的方法
2014/11/11 PHP
PHP输出XML格式数据的方法总结
2017/02/08 PHP
php中html_entity_decode实现HTML实体转义
2018/06/13 PHP
php进程daemon化的正确实现方法
2018/09/06 PHP
Linux下 php7安装redis的方法
2018/11/01 PHP
php伪静态验证码不显示的解决方案
2019/09/26 PHP
JS获取客户端IP地址、MAC和主机名的7个方法汇总
2014/07/21 Javascript
AngularJS 自定义指令详解及示例代码
2016/08/17 Javascript
mui上拉加载功能实例详解
2017/04/13 Javascript
JS实现全屏预览F11功能的示例代码
2018/07/23 Javascript
KOA+egg.js集成kafka消息队列的示例
2018/11/09 Javascript
vue.js使用v-model实现表单元素(input) 双向数据绑定功能示例
2019/03/08 Javascript
[02:23]2014DOTA2国际邀请赛中国战队回顾
2014/08/01 DOTA
[51:29]完美世界DOTA2联赛循环赛 Matador vs Forest BO2第一场 11.05
2020/11/05 DOTA
python实现汉诺塔递归算法经典案例
2021/03/01 Python
深入解析Python的Tornado框架中内置的模板引擎
2016/07/11 Python
动感网页相册 python编写简单文件夹内图片浏览工具
2016/08/17 Python
python中subprocess批量执行linux命令
2018/04/27 Python
浅谈django orm 优化
2018/08/18 Python
python scatter函数用法实例详解
2020/02/11 Python
Python3之乱码\xe6\x97\xa0\xe6\xb3\x95处理方式
2020/05/11 Python
html5组织文档结构_动力节点Java学院整理
2017/07/11 HTML / CSS
美国领先的户外服装与装备用品店:Moosejaw
2016/08/25 全球购物
优衣库英国官网:UNIQLO英国
2016/12/25 全球购物
Three Graces London官网:英国奢侈品牌
2021/03/18 全球购物
应用服务器有那些
2012/01/19 面试题
Linux操作面试题
2012/05/16 面试题
青年志愿者事迹材料
2014/02/07 职场文书
爱牙日活动总结
2014/08/29 职场文书
乡镇党的群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
2014民事授权委托书范本
2014/09/29 职场文书
病房管理制度范本
2015/08/06 职场文书
mysql备份策略的实现(全量备份+增量备份)
2021/07/07 MySQL
golang使用map实现去除重复数组
2022/04/14 Golang
create-react-app开发常用配置教程
2022/06/25 Javascript