浅谈Pytorch中的自动求导函数backward()所需参数的含义


Posted in Python onFebruary 29, 2020

正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿。

对标量自动求导

首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的。

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([2,3]),requires_grad=True)
b = a + 3
c = b * 3
out = c.mean()
out.backward()
print('input:')
print(a.data)
print('output:')
print(out.data.item())
print('input gradients are:')
print(a.grad)

运行结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

不难看出,我们构建了这样的一个函数:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

所以其求导也很容易看出:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这是对其进行标量自动求导的结果.

对向量自动求导

如果out.backward()中的out是一个向量(或者理解成1xN的矩阵)的话,我们对向量进行自动求导,看看会发生什么?

先构建这样的一个模型(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有两个输出):

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 
b[0,1] = a[0,1] ** 3 
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

模型也很简单,不难看出out求导出来的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

因为a1 = 2,a2 = 4,所以上面的矩阵应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行的结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,的确是8和96,但是仔细想一想,和咱们想要的雅克比矩阵的形式也不一样啊。难道是backward自动把0给省略了?

咱们继续试试,这次在上一个模型的基础上进行小修改,如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

可以看出这个模型的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行一下:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

等等,什么鬼?正常来说不应该是

浅谈Pytorch中的自动求导函数backward()所需参数的含义

么?我是谁?我再哪?为什么就给我2个数,而且是 8 + 2 = 10 ,96 + 2 = 98 。难道都是加的 2 ?想一想,刚才咱们backward中传的参数是 [ [ 1 , 1 ] ],难道安装这个关系对应求和了?咱们换个参数来试一试,程序中只更改传入的参数为[ [ 1 , 2 ] ]:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,2.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,这回可以理解了,我们传入的参数,是对原来模型正常求导出来的雅克比矩阵进行线性操作,可以把我们传进的参数(设为arg)看成一个列向量,那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

在这个题目中,我们得到的实际是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

看起来一切完美的解释了,但是就在我刚刚打字的一刻,我意识到官方文档中说k.backward()传入的参数应该和k具有相同的维度,所以如果按上述去解释是解释不通的。哪里出问题了呢?

仔细看了一下,原来是这样的:在对雅克比矩阵进行线性操作的时候,应该把我们传进的参数(设为arg)看成一个行向量(不是列向量),那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这回我们就解释的通了。

现在我们来输出一下雅克比矩阵吧,为了不引起歧义,我们让雅克比矩阵的每个数值都不一样(一开始分析错了就是因为雅克比矩阵中有相同的数据),所以模型小改动如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0] * 2
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1,0]]),retain_graph=True)
A_temp = copy.deepcopy(a.grad)
a.grad.zero_()
out.backward(torch.FloatTensor([[0,1]]))
B_temp = a.grad
print('jacobian matrix is:')
print(torch.cat( (A_temp,B_temp),0 ))

如果没问题的话咱们的雅克比矩阵应该是 [ [ 8 , 2 ] , [ 4 , 96 ] ]

好了,下面是见证奇迹的时刻了,不要眨眼睛奥,千万不要眨眼睛… 3 2 1 砰…

浅谈Pytorch中的自动求导函数backward()所需参数的含义

好了,现在总结一下:因为经过了复杂的神经网络之后,out中每个数值都是由很多输入样本的属性(也就是输入数据)线性或者非线性组合而成的,那么out中的每个数值和输入数据的每个数值都有关联,也就是说【out】中的每个数都可以对【a】中每个数求导,那么我们backward()的参数[k1,k2,k3…kn]的含义就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也可以理解成每个out分量对an求导时的权重。

对矩阵自动求导

现在,如果out是一个矩阵呢?

下面的例子也可以理解为:相当于一个神经网络有两个样本,每个样本有两个属性,神经网络有两个输出。

import torch
from torch.autograd import Variable
from torch import nn

a = Variable(torch.FloatTensor([[2,3],[1,2]]),requires_grad=True)
w = Variable( torch.zeros(2,1),requires_grad=True )
out = torch.mm(a,w)
out.backward(torch.FloatTensor([[1.],[1.]]),retain_graph=True)
print("gradients are:{}".format(w.grad.data))

如果前面的例子理解了,那么这个也很好理解,backward输入的参数k是一个2x1的矩阵,2代表的就是样本数量,就是在前面的基础上,再对每个样本进行加权求和。结果是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

如果有兴趣,也可以拓展一下多个样本的多分类问题,猜一下k的维度应该是【输入样本的个数 * 分类的个数】

好啦,纠结我好久的pytorch自动求导原理算是彻底搞懂啦~~~

以上这篇浅谈Pytorch中的自动求导函数backward()所需参数的含义就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python选择排序算法的实现代码
Nov 21 Python
对Python的Django框架中的项目进行单元测试的方法
Apr 11 Python
实例探究Python以并发方式编写高性能端口扫描器的方法
Jun 14 Python
Python实现将一个大文件按段落分隔为多个小文件的简单操作方法
Apr 17 Python
详解Python异常处理中的Finally else的功能
Dec 29 Python
Pycharm导入Python包,模块的图文教程
Jun 13 Python
解决Python安装时报缺少DLL问题【两种解决方法】
Jul 15 Python
django写用户登录判定并跳转制定页面的实例
Aug 21 Python
pytest中文文档之编写断言
Sep 12 Python
keras获得model中某一层的某一个Tensor的输出维度教程
Jan 24 Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 Python
如何利用Matlab制作一款真正的拼图小游戏
May 11 Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 #Python
You might like
php批量缩放图片的代码[ini参数控制]
2011/02/11 PHP
PHP连接sql server 2005环境配置及问题解决
2014/08/08 PHP
PHP框架Swoole定时器Timer特性分析
2014/08/19 PHP
php创建多级目录与级联删除文件的方法示例
2019/09/12 PHP
用javascript来实现动画导航效果的代码
2007/12/16 Javascript
jQuery的实现原理的模拟代码 -1 核心部分
2010/08/01 Javascript
拖动布局之保存布局页面cookies篇
2010/10/29 Javascript
3款实用的在线JS代码工具(国外)
2012/03/15 Javascript
jQuery控制图片的hover效果(smartRollover.js)
2012/03/18 Javascript
IE的事件传递-event.cancelBubble示例介绍
2014/01/12 Javascript
Extjs的FileUploadField文件上传出现了两个上传按钮
2014/04/29 Javascript
JS中如何实现Laravel的route函数详解
2017/02/12 Javascript
webpack学习--webpack经典7分钟入门教程
2017/06/28 Javascript
Vue 组件修改根实例的数据的方法
2019/04/02 Javascript
Openlayers实现地图的基本操作
2020/09/28 Javascript
[02:23]1个至宝=115个英雄特效 最“绿”至宝拉比克“魔导师密钥”登场
2018/12/29 DOTA
python的描述符(descriptor)、装饰器(property)造成的一个无限递归问题分享
2014/07/09 Python
Python字典简介以及用法详解
2016/11/15 Python
视觉直观感受若干常用排序算法
2017/04/13 Python
python Tkinter版学生管理系统
2019/02/20 Python
Python实现Selenium自动化Page模式
2019/07/14 Python
Python计算公交发车时间的完整代码
2020/02/12 Python
实现Python3数组旋转的3种算法实例
2020/09/16 Python
美国复古街头服饰精品店:Need Supply Co.
2017/02/22 全球购物
企业厂长岗位职责
2013/12/17 职场文书
优秀士兵个人事迹材料
2014/01/19 职场文书
高级编程求职信模板
2014/02/16 职场文书
新年主持词
2014/03/27 职场文书
感恩的演讲稿
2014/05/06 职场文书
老公给老婆的检讨书(精华篇)
2014/10/18 职场文书
小学教师先进事迹材料
2014/12/15 职场文书
大学生实习证明
2015/06/16 职场文书
小学作文之描写天气
2019/08/15 职场文书
Canvas跟随鼠标炫彩小球的实现
2021/04/11 Javascript
教你怎么用Python实现多路径迷宫
2021/04/29 Python
js前端面试常见浏览器缓存强缓存及协商缓存实例
2022/06/21 Javascript