浅谈Pytorch中的自动求导函数backward()所需参数的含义


Posted in Python onFebruary 29, 2020

正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿。

对标量自动求导

首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的。

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([2,3]),requires_grad=True)
b = a + 3
c = b * 3
out = c.mean()
out.backward()
print('input:')
print(a.data)
print('output:')
print(out.data.item())
print('input gradients are:')
print(a.grad)

运行结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

不难看出,我们构建了这样的一个函数:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

所以其求导也很容易看出:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这是对其进行标量自动求导的结果.

对向量自动求导

如果out.backward()中的out是一个向量(或者理解成1xN的矩阵)的话,我们对向量进行自动求导,看看会发生什么?

先构建这样的一个模型(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有两个输出):

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 
b[0,1] = a[0,1] ** 3 
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

模型也很简单,不难看出out求导出来的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

因为a1 = 2,a2 = 4,所以上面的矩阵应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行的结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,的确是8和96,但是仔细想一想,和咱们想要的雅克比矩阵的形式也不一样啊。难道是backward自动把0给省略了?

咱们继续试试,这次在上一个模型的基础上进行小修改,如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

可以看出这个模型的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行一下:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

等等,什么鬼?正常来说不应该是

浅谈Pytorch中的自动求导函数backward()所需参数的含义

么?我是谁?我再哪?为什么就给我2个数,而且是 8 + 2 = 10 ,96 + 2 = 98 。难道都是加的 2 ?想一想,刚才咱们backward中传的参数是 [ [ 1 , 1 ] ],难道安装这个关系对应求和了?咱们换个参数来试一试,程序中只更改传入的参数为[ [ 1 , 2 ] ]:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,2.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,这回可以理解了,我们传入的参数,是对原来模型正常求导出来的雅克比矩阵进行线性操作,可以把我们传进的参数(设为arg)看成一个列向量,那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

在这个题目中,我们得到的实际是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

看起来一切完美的解释了,但是就在我刚刚打字的一刻,我意识到官方文档中说k.backward()传入的参数应该和k具有相同的维度,所以如果按上述去解释是解释不通的。哪里出问题了呢?

仔细看了一下,原来是这样的:在对雅克比矩阵进行线性操作的时候,应该把我们传进的参数(设为arg)看成一个行向量(不是列向量),那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这回我们就解释的通了。

现在我们来输出一下雅克比矩阵吧,为了不引起歧义,我们让雅克比矩阵的每个数值都不一样(一开始分析错了就是因为雅克比矩阵中有相同的数据),所以模型小改动如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0] * 2
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1,0]]),retain_graph=True)
A_temp = copy.deepcopy(a.grad)
a.grad.zero_()
out.backward(torch.FloatTensor([[0,1]]))
B_temp = a.grad
print('jacobian matrix is:')
print(torch.cat( (A_temp,B_temp),0 ))

如果没问题的话咱们的雅克比矩阵应该是 [ [ 8 , 2 ] , [ 4 , 96 ] ]

好了,下面是见证奇迹的时刻了,不要眨眼睛奥,千万不要眨眼睛… 3 2 1 砰…

浅谈Pytorch中的自动求导函数backward()所需参数的含义

好了,现在总结一下:因为经过了复杂的神经网络之后,out中每个数值都是由很多输入样本的属性(也就是输入数据)线性或者非线性组合而成的,那么out中的每个数值和输入数据的每个数值都有关联,也就是说【out】中的每个数都可以对【a】中每个数求导,那么我们backward()的参数[k1,k2,k3…kn]的含义就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也可以理解成每个out分量对an求导时的权重。

对矩阵自动求导

现在,如果out是一个矩阵呢?

下面的例子也可以理解为:相当于一个神经网络有两个样本,每个样本有两个属性,神经网络有两个输出。

import torch
from torch.autograd import Variable
from torch import nn

a = Variable(torch.FloatTensor([[2,3],[1,2]]),requires_grad=True)
w = Variable( torch.zeros(2,1),requires_grad=True )
out = torch.mm(a,w)
out.backward(torch.FloatTensor([[1.],[1.]]),retain_graph=True)
print("gradients are:{}".format(w.grad.data))

如果前面的例子理解了,那么这个也很好理解,backward输入的参数k是一个2x1的矩阵,2代表的就是样本数量,就是在前面的基础上,再对每个样本进行加权求和。结果是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

如果有兴趣,也可以拓展一下多个样本的多分类问题,猜一下k的维度应该是【输入样本的个数 * 分类的个数】

好啦,纠结我好久的pytorch自动求导原理算是彻底搞懂啦~~~

以上这篇浅谈Pytorch中的自动求导函数backward()所需参数的含义就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Mysql自动备份脚本
Jul 14 Python
python使用os模块的os.walk遍历文件夹示例
Jan 27 Python
python使用Queue在多个子进程间交换数据的方法
Apr 18 Python
python监控键盘输入实例代码
Feb 09 Python
matlab中实现矩阵删除一行或一列的方法
Apr 04 Python
Python基本socket通信控制操作示例
Jan 30 Python
python实现nao机器人手臂动作控制
Apr 29 Python
Django中的FBV和CBV用法详解
Sep 15 Python
在 Windows 下搭建高效的 django 开发环境的详细教程
Jul 27 Python
详解Django中views数据查询使用locals()函数进行优化
Aug 24 Python
Python GUI库Tkiner使用方法代码示例
Nov 27 Python
python缺失值的解决方法总结
Jun 09 Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 #Python
You might like
PHP设计模式之装饰者模式
2012/02/29 PHP
用PHP实现浏览器点击下载TXT文档的方法详解
2013/06/02 PHP
比较strtr, str_replace和preg_replace三个函数的效率
2013/06/26 PHP
php导入导出excel实例
2013/10/25 PHP
6种php上传图片重命名的方法实例
2013/11/04 PHP
php从csv文件读取数据并输出到网页的方法
2015/03/14 PHP
php编写简单的文章发布程序
2015/06/18 PHP
PHP count_chars()函数讲解
2019/02/14 PHP
laravel5.1 ajax post 传值_token示例
2019/10/24 PHP
PHP接入微信H5支付的方法示例
2019/10/28 PHP
关于跨站脚本攻击问题
2011/12/22 Javascript
用JavaScript判断CSS浏览器类型前缀的两种方法
2015/10/08 Javascript
js 获取本地文件及目录的方法(推荐)
2016/11/10 Javascript
ionic2打包android时gradle无法下载的解决方法
2017/04/05 Javascript
解决vue.js在编写过程中出现空格不规范报错的问题
2017/09/20 Javascript
微信小程序入口场景的问题集合与相关解决方法
2019/06/26 Javascript
python赋值操作方法分享
2013/03/23 Python
Python fileinput模块使用实例
2015/05/28 Python
快速实现基于Python的微信聊天机器人示例代码
2017/03/03 Python
Python如何生成树形图案
2018/01/03 Python
3分钟学会一个Python小技巧
2018/11/23 Python
Django视图扩展类知识点详解
2019/10/25 Python
python使用梯度下降和牛顿法寻找Rosenbrock函数最小值实例
2020/04/02 Python
解决Keras使用GPU资源耗尽的问题
2020/06/22 Python
Python requests上传文件实现步骤
2020/09/15 Python
canvas实现圆绘制的示例代码
2019/09/11 HTML / CSS
Stefania Mode英国:奢华设计师和时尚服装
2017/10/23 全球购物
澳大利亚婴儿礼品公司:The Baby Gift Company
2018/11/04 全球购物
大专生自我鉴定范文
2013/10/01 职场文书
服装厂厂长职责
2013/12/16 职场文书
蓝颜请假条
2014/04/11 职场文书
大学生个人求职信
2014/06/02 职场文书
公司户外活动总结
2014/07/04 职场文书
承诺书的内容有哪些,怎么写?
2019/06/21 职场文书
python - timeit 时间模块
2021/04/06 Python
php去除deprecated的实例方法
2021/11/17 PHP