Pytorch: 自定义网络层实例


Posted in Python onJanuary 07, 2020

自定义Autograd函数

对于浅层的网络,我们可以手动的书写前向传播和反向传播过程。但是当网络变得很大时,特别是在做深度学习时,网络结构变得复杂。前向传播和反向传播也随之变得复杂,手动书写这两个过程就会存在很大的困难。幸运地是在pytorch中存在了自动微分的包,可以用来解决该问题。在使用自动求导的时候,网络的前向传播会定义一个计算图(computational graph),图中的节点是张量(tensor),两个节点之间的边对应了两个张量之间变换关系的函数。有了计算图的存在,张量的梯度计算也变得容易了些。例如, x是一个张量,其属性 x.requires_grad = True,那么 x.grad就是一个保存这个张量x的梯度的一些标量值。

最基础的自动求导操作在底层就是作用在两个张量上。前向传播函数是从输入张量到输出张量的计算过程;反向传播是输入输出张量的梯度(一些标量)并输出输入张量的梯度(一些标量)。在pytorch中我们可以很容易地定义自己的自动求导操作,通过继承torch.autograd.Function并定义forward和backward函数。

forward(): 前向传播操作。可以输入任意多的参数,任意的python对象都可以。

backward():反向传播(梯度公式)。输出的梯度个数需要与所使用的张量个数保持一致,且返回的顺序也要对应起来。

# Inherit from Function
class LinearFunction(Function):

  # Note that both forward and backward are @staticmethods
  @staticmethod
  # bias is an optional argument
  def forward(ctx, input, weight, bias=None):
    # ctx在这里类似self,ctx的属性可以在backward中调用
    ctx.save_for_backward(input, weight, bias)
    output = input.mm(weight.t())
    if bias is not None:
      output += bias.unsqueeze(0).expand_as(output)
    return output

  # This function has only a single output, so it gets only one gradient
  @staticmethod
  def backward(ctx, grad_output):
    # This is a pattern that is very convenient - at the top of backward
    # unpack saved_tensors and initialize all gradients w.r.t. inputs to
    # None. Thanks to the fact that additional trailing Nones are
    # ignored, the return statement is simple even when the function has
    # optional inputs.
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None

    # These needs_input_grad checks are optional and there only to
    # improve efficiency. If you want to make your code simpler, you can
    # skip them. Returning gradients for inputs that don't require it is
    # not an error.
    if ctx.needs_input_grad[0]:
      grad_input = grad_output.mm(weight)
    if ctx.needs_input_grad[1]:
      grad_weight = grad_output.t().mm(input)
    if bias is not None and ctx.needs_input_grad[2]:
      grad_bias = grad_output.sum(0).squeeze(0)

    return grad_input, grad_weight, grad_bias

#调用自定义的自动求导函数
linear = LinearFunction.apply(*args) #前向传播
linear.backward()#反向传播
linear.grad_fn.apply(*args)#反向传播

对于非参数化的张量(权重是常量,不需要更新),此时可以定义为:

class MulConstant(Function):
  @staticmethod
  def forward(ctx, tensor, constant):
    # ctx is a context object that can be used to stash information
    # for backward computation
    ctx.constant = constant
    return tensor * constant

  @staticmethod
  def backward(ctx, grad_output):
    # We return as many input gradients as there were arguments.
    # Gradients of non-Tensor arguments to forward must be None.
    return grad_output * ctx.constant, None

高阶导数

grad_x =t.autograd.grad(y, x, create_graph=True)

grad_grad_x = t.autograd.grad(grad_x[0],x)

自定义Module

计算图和自动求导在定义复杂网络和求梯度的时候非常好用,但对于大型的网络,这个还是有点偏底层。在我们构建网络的时候,经常希望将计算限制在每个层之内(参数更新分层更新)。而且在TensorFlow等其他深度学习框架中都提供了高级抽象结构。因此,在pytorch中也提供了类似的包nn,它定义了一组等价于层(layer)的模块(Modules)。一个Module接受输入张量并得到输出张量,同时也会包含可学习的参数。

有时候,我们希望运用一些新的且nn包中不存在的Module。此时就需要定义自己的Module了。自定义的Module需要继承nn.Module且自定义forward函数。其中forward函数可以接受输入张量并利用其它模型或者其他自动求导操作来产生输出张量。但并不需要重写backward函数,因此nn使用了autograd。这也就意味着,需要自定义Module, 都必须有对应的autograd函数以调用其中的backward。

class Linear(nn.Module):
  def __init__(self, input_features, output_features, bias=True):
    super(Linear, self).__init__()
    self.input_features = input_features
    self.output_features = output_features

    # nn.Parameter is a special kind of Tensor, that will get
    # automatically registered as Module's parameter once it's assigned
    # as an attribute. Parameters and buffers need to be registered, or
    # they won't appear in .parameters() (doesn't apply to buffers), and
    # won't be converted when e.g. .cuda() is called. You can use
    # .register_buffer() to register buffers.
    # (很重要!!!参数一定需要梯度!)nn.Parameters require gradients by default.
    self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
    if bias:
      self.bias = nn.Parameter(torch.Tensor(output_features))
    else:
      # You should always register all possible parameters, but the
      # optional ones can be None if you want.
      self.register_parameter('bias', None)

    # Not a very smart way to initialize weights
    self.weight.data.uniform_(-0.1, 0.1)
    if bias is not None:
      self.bias.data.uniform_(-0.1, 0.1)

  def forward(self, input):
    # See the autograd section for explanation of what happens here.
    return LinearFunction.apply(input, self.weight, self.bias)

  def extra_repr(self):
    # (Optional)Set the extra information about this module. You can test
    # it by printing an object of this class.
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None

Function与Module的异同

Function与Module都可以对pytorch进行自定义拓展,使其满足网络的需求,但这两者还是有十分重要的不同:

Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数、pooling等操作;Module是保存了参数,因此适合于定义一层,如线性层,卷积层,也适用于定义一个网络

Function需要定义三个方法:init, forward, backward(需要自己写求导公式);Module:只需定义init和forward,而backward的计算由自动求导机制构成

可以不严谨的认为,Module是由一系列Function组成,因此其在forward的过程中,Function和Variable组成了计算图,在backward时,只需调用Function的backward就得到结果,因此Module不需要再定义backward。

Module不仅包括了Function,还包括了对应的参数,以及其他函数与变量,这是Function所不具备的。

module 是 pytorch 组织神经网络的基本方式。Module 包含了模型的参数以及计算逻辑。Function 承载了实际的功能,定义了前向和后向的计算逻辑。

Module 是任何神经网络的基类,pytorch 中所有模型都必需是 Module 的子类。 Module 可以套嵌,构成树状结构。一个 Module 可以通过将其他 Module 做为属性的方式,完成套嵌。

Function 是 pytorch 自动求导机制的核心类。Function 是无参数或者说无状态的,它只负责接收输入,返回相应的输出;对于反向,它接收输出相应的梯度,返回输入相应的梯度。

在调用loss.backward()时,使用的是Function子类中定义的backward()函数。

以上这篇Pytorch: 自定义网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现批量将word转html并将html内容发布至网站的方法
Jul 14 Python
python实现网站的模拟登录
Jan 04 Python
centos6.7安装python2.7.11的具体方法
Jan 16 Python
matplotlib设置legend图例代码示例
Dec 19 Python
详解Python中 sys.argv[]的用法简明解释
Dec 20 Python
Python简单实现阿拉伯数字和罗马数字的互相转换功能示例
Apr 17 Python
flask-socketio实现WebSocket的方法
Jul 31 Python
Python 数值区间处理_对interval 库的快速入门详解
Nov 16 Python
python的移位操作实现详解
Aug 21 Python
Python 生成VOC格式的标签实例
Mar 10 Python
pyspark给dataframe增加新的一列的实现示例
Apr 24 Python
Python DES加密实现原理及实例解析
Jul 17 Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
python单例设计模式实现解析
Jan 07 #Python
You might like
页面利用渐进式JPEG来提升用户体验度
2014/12/01 PHP
php+xml实现在线英文词典之添加词条的方法
2015/01/23 PHP
PHPStrom 新建FTP项目以及在线操作教程
2016/10/16 PHP
php正则去除网页中所有的html,js,css,注释的实现方法
2016/11/03 PHP
php实现mysql连接池效果实现代码
2018/01/25 PHP
PHP INT类型在内存中占字节详解
2019/07/20 PHP
记Laravel调用Gin接口调用formData上传文件的实现方法
2019/12/12 PHP
js parseInt("08")未指定进位制问题
2010/06/19 Javascript
js遍历子节点子元素附属性及方法
2014/08/19 Javascript
最流行的Node.js精简型和全栈型开发框架介绍
2015/02/26 Javascript
JQuery控制Radio选中方法分析
2015/05/29 Javascript
Javascript中的迭代、归并方法详解
2016/06/14 Javascript
jQuery插件HighCharts绘制2D柱状图、折线图的组合双轴图效果示例【附demo源码下载】
2017/03/09 Javascript
ActiveX控件的使用-js实现打印超市小票功能代码详解
2017/11/22 Javascript
浅谈ElementUI中switch回调函数change的参数问题
2018/08/24 Javascript
JS实现京东商品分类侧边栏
2020/12/11 Javascript
python实现划词翻译
2020/04/23 Python
Python中3种内建数据结构:列表、元组和字典
2014/11/30 Python
python单元测试unittest实例详解
2015/05/11 Python
初步讲解Python中的元组概念
2015/05/21 Python
Python调用SQLPlus来操作和解析Oracle数据库的方法
2016/04/09 Python
用于业余项目的8个优秀Python库
2018/09/21 Python
pytorch torch.expand和torch.repeat的区别详解
2019/11/05 Python
python中设置超时跳过,超时退出的方式
2019/12/13 Python
Python 余弦相似度与皮尔逊相关系数 计算实例
2019/12/23 Python
Python Django搭建网站流程图解
2020/06/13 Python
canvas进阶之贝塞尔公式推导与物体跟随复杂曲线的轨迹运动
2018/01/10 HTML / CSS
美国一站式电动和手动工具商店:International Tool
2020/11/26 全球购物
小班重阳节活动方案
2014/02/08 职场文书
档案室主任岗位职责
2014/02/12 职场文书
论文诚信承诺书
2014/05/23 职场文书
关于保护环境的标语
2014/06/09 职场文书
追悼会悼词大全
2015/06/23 职场文书
小学英语教学经验交流材料
2015/11/02 职场文书
分析并发编程之LongAdder原理
2021/06/29 Java/Android
关于springboot配置druid数据源不生效问题(踩坑记)
2021/09/25 Java/Android