浅谈Pytorch中的自动求导函数backward()所需参数的含义


Posted in Python onFebruary 29, 2020

正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿。

对标量自动求导

首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的。

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([2,3]),requires_grad=True)
b = a + 3
c = b * 3
out = c.mean()
out.backward()
print('input:')
print(a.data)
print('output:')
print(out.data.item())
print('input gradients are:')
print(a.grad)

运行结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

不难看出,我们构建了这样的一个函数:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

所以其求导也很容易看出:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这是对其进行标量自动求导的结果.

对向量自动求导

如果out.backward()中的out是一个向量(或者理解成1xN的矩阵)的话,我们对向量进行自动求导,看看会发生什么?

先构建这样的一个模型(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有两个输出):

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 
b[0,1] = a[0,1] ** 3 
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

模型也很简单,不难看出out求导出来的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

因为a1 = 2,a2 = 4,所以上面的矩阵应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行的结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,的确是8和96,但是仔细想一想,和咱们想要的雅克比矩阵的形式也不一样啊。难道是backward自动把0给省略了?

咱们继续试试,这次在上一个模型的基础上进行小修改,如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

可以看出这个模型的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行一下:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

等等,什么鬼?正常来说不应该是

浅谈Pytorch中的自动求导函数backward()所需参数的含义

么?我是谁?我再哪?为什么就给我2个数,而且是 8 + 2 = 10 ,96 + 2 = 98 。难道都是加的 2 ?想一想,刚才咱们backward中传的参数是 [ [ 1 , 1 ] ],难道安装这个关系对应求和了?咱们换个参数来试一试,程序中只更改传入的参数为[ [ 1 , 2 ] ]:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,2.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,这回可以理解了,我们传入的参数,是对原来模型正常求导出来的雅克比矩阵进行线性操作,可以把我们传进的参数(设为arg)看成一个列向量,那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

在这个题目中,我们得到的实际是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

看起来一切完美的解释了,但是就在我刚刚打字的一刻,我意识到官方文档中说k.backward()传入的参数应该和k具有相同的维度,所以如果按上述去解释是解释不通的。哪里出问题了呢?

仔细看了一下,原来是这样的:在对雅克比矩阵进行线性操作的时候,应该把我们传进的参数(设为arg)看成一个行向量(不是列向量),那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这回我们就解释的通了。

现在我们来输出一下雅克比矩阵吧,为了不引起歧义,我们让雅克比矩阵的每个数值都不一样(一开始分析错了就是因为雅克比矩阵中有相同的数据),所以模型小改动如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0] * 2
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1,0]]),retain_graph=True)
A_temp = copy.deepcopy(a.grad)
a.grad.zero_()
out.backward(torch.FloatTensor([[0,1]]))
B_temp = a.grad
print('jacobian matrix is:')
print(torch.cat( (A_temp,B_temp),0 ))

如果没问题的话咱们的雅克比矩阵应该是 [ [ 8 , 2 ] , [ 4 , 96 ] ]

好了,下面是见证奇迹的时刻了,不要眨眼睛奥,千万不要眨眼睛… 3 2 1 砰…

浅谈Pytorch中的自动求导函数backward()所需参数的含义

好了,现在总结一下:因为经过了复杂的神经网络之后,out中每个数值都是由很多输入样本的属性(也就是输入数据)线性或者非线性组合而成的,那么out中的每个数值和输入数据的每个数值都有关联,也就是说【out】中的每个数都可以对【a】中每个数求导,那么我们backward()的参数[k1,k2,k3…kn]的含义就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也可以理解成每个out分量对an求导时的权重。

对矩阵自动求导

现在,如果out是一个矩阵呢?

下面的例子也可以理解为:相当于一个神经网络有两个样本,每个样本有两个属性,神经网络有两个输出。

import torch
from torch.autograd import Variable
from torch import nn

a = Variable(torch.FloatTensor([[2,3],[1,2]]),requires_grad=True)
w = Variable( torch.zeros(2,1),requires_grad=True )
out = torch.mm(a,w)
out.backward(torch.FloatTensor([[1.],[1.]]),retain_graph=True)
print("gradients are:{}".format(w.grad.data))

如果前面的例子理解了,那么这个也很好理解,backward输入的参数k是一个2x1的矩阵,2代表的就是样本数量,就是在前面的基础上,再对每个样本进行加权求和。结果是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

如果有兴趣,也可以拓展一下多个样本的多分类问题,猜一下k的维度应该是【输入样本的个数 * 分类的个数】

好啦,纠结我好久的pytorch自动求导原理算是彻底搞懂啦~~~

以上这篇浅谈Pytorch中的自动求导函数backward()所需参数的含义就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中输出ASCII大文字、艺术字、字符字小技巧
Apr 28 Python
Python常用小技巧总结
Jun 01 Python
python实现将英文单词表示的数字转换成阿拉伯数字的方法
Jul 02 Python
Django中的CACHE_BACKEND参数和站点级Cache设置
Jul 23 Python
Python 爬虫学习笔记之单线程爬虫
Sep 21 Python
Python实现购物车功能的方法分析
Nov 10 Python
分析python动态规划的递归、非递归实现
Mar 04 Python
Python PyQt4实现QQ抽屉效果
Apr 20 Python
解决Pyinstaller 打包exe文件 取消dos窗口(黑框框)的问题
Jun 21 Python
python程序快速缩进多行代码方法总结
Jun 23 Python
pandas 使用均值填充缺失值列的小技巧分享
Jul 04 Python
python basemap 画出经纬度并标定的实例
Jul 09 Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 #Python
You might like
php编写的抽奖程序中奖概率算法
2015/05/14 PHP
Nginx服务器上安装并配置PHPMyAdmin的教程
2015/08/18 PHP
Symfony2之session与cookie用法小结
2016/03/18 PHP
JavaScript 继承详解(一)
2009/07/13 Javascript
jQuery中使用了document和window哪些属性和方法小结
2011/09/13 Javascript
AngularJS语法详解(续)
2015/01/23 Javascript
jQuery插件Elastislide实现响应式的焦点图无缝滚动切换特效
2015/04/12 Javascript
理解javascript封装
2016/02/23 Javascript
javascript关于继承解析
2016/05/10 Javascript
用AngularJS的指令实现tabs切换效果
2016/08/31 Javascript
JS实现获取图片大小和预览的方法完整实例【兼容IE和其它浏览器】
2017/04/24 Javascript
JS 使用 window对象的print方法实现分页打印功能
2018/05/16 Javascript
JavaScript 高性能数组去重的方法
2018/09/20 Javascript
用Electron写个带界面的nodejs爬虫的实现方法
2019/01/29 NodeJs
javascript关于“时间”的一次探索
2019/07/24 Javascript
Javascript原生ajax请求代码实例
2020/02/20 Javascript
VUE动态生成word的实现
2020/07/26 Javascript
python读写json文件的简单实现
2017/04/11 Python
pandas 实现将重复表格去重,并重新转换为表格的方法
2018/04/18 Python
详解Python学习之安装pandas
2019/04/16 Python
将Python文件打包成.EXE可执行文件的方法
2019/08/11 Python
python datetime中strptime用法详解
2019/08/29 Python
python计算无向图节点度的实例代码
2019/11/22 Python
关于Theano和Tensorflow多GPU使用问题
2020/06/19 Python
非洲NO.1网上商店:Jumia肯尼亚
2016/08/18 全球购物
印度婴儿用品在线商店:Firstcry.com
2016/12/05 全球购物
关于Java String的一道面试题
2013/09/29 面试题
新闻专业本科生的自我评价分享
2013/11/20 职场文书
槐乡的孩子教学反思
2014/04/27 职场文书
圣诞节活动策划方案
2014/06/09 职场文书
机关中层领导干部群众路线教育实践活动个人对照检查材料
2014/09/24 职场文书
就业推荐表导师评语
2014/12/31 职场文书
杭州西湖英语导游词
2015/02/03 职场文书
幼儿园教师个人工作总结2015
2015/05/12 职场文书
2016年“抗战胜利纪念日”71周年校园广播稿
2015/12/18 职场文书
使用springMVC所需要的pom配置
2021/09/15 Java/Android