浅谈Pytorch中的自动求导函数backward()所需参数的含义


Posted in Python onFebruary 29, 2020

正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿。

对标量自动求导

首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的。

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([2,3]),requires_grad=True)
b = a + 3
c = b * 3
out = c.mean()
out.backward()
print('input:')
print(a.data)
print('output:')
print(out.data.item())
print('input gradients are:')
print(a.grad)

运行结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

不难看出,我们构建了这样的一个函数:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

所以其求导也很容易看出:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这是对其进行标量自动求导的结果.

对向量自动求导

如果out.backward()中的out是一个向量(或者理解成1xN的矩阵)的话,我们对向量进行自动求导,看看会发生什么?

先构建这样的一个模型(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有两个输出):

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 
b[0,1] = a[0,1] ** 3 
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

模型也很简单,不难看出out求导出来的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

因为a1 = 2,a2 = 4,所以上面的矩阵应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行的结果:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,的确是8和96,但是仔细想一想,和咱们想要的雅克比矩阵的形式也不一样啊。难道是backward自动把0给省略了?

咱们继续试试,这次在上一个模型的基础上进行小修改,如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

可以看出这个模型的雅克比应该是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

运行一下:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

等等,什么鬼?正常来说不应该是

浅谈Pytorch中的自动求导函数backward()所需参数的含义

么?我是谁?我再哪?为什么就给我2个数,而且是 8 + 2 = 10 ,96 + 2 = 98 。难道都是加的 2 ?想一想,刚才咱们backward中传的参数是 [ [ 1 , 1 ] ],难道安装这个关系对应求和了?咱们换个参数来试一试,程序中只更改传入的参数为[ [ 1 , 2 ] ]:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,2.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)

浅谈Pytorch中的自动求导函数backward()所需参数的含义

嗯,这回可以理解了,我们传入的参数,是对原来模型正常求导出来的雅克比矩阵进行线性操作,可以把我们传进的参数(设为arg)看成一个列向量,那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

在这个题目中,我们得到的实际是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

看起来一切完美的解释了,但是就在我刚刚打字的一刻,我意识到官方文档中说k.backward()传入的参数应该和k具有相同的维度,所以如果按上述去解释是解释不通的。哪里出问题了呢?

仔细看了一下,原来是这样的:在对雅克比矩阵进行线性操作的时候,应该把我们传进的参数(设为arg)看成一个行向量(不是列向量),那么我们得到的结果就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

这回我们就解释的通了。

现在我们来输出一下雅克比矩阵吧,为了不引起歧义,我们让雅克比矩阵的每个数值都不一样(一开始分析错了就是因为雅克比矩阵中有相同的数据),所以模型小改动如下:

import torch
from torch.autograd import Variable
 
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1] 
b[0,1] = a[0,1] ** 3 + a[0,0] * 2
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1,0]]),retain_graph=True)
A_temp = copy.deepcopy(a.grad)
a.grad.zero_()
out.backward(torch.FloatTensor([[0,1]]))
B_temp = a.grad
print('jacobian matrix is:')
print(torch.cat( (A_temp,B_temp),0 ))

如果没问题的话咱们的雅克比矩阵应该是 [ [ 8 , 2 ] , [ 4 , 96 ] ]

好了,下面是见证奇迹的时刻了,不要眨眼睛奥,千万不要眨眼睛… 3 2 1 砰…

浅谈Pytorch中的自动求导函数backward()所需参数的含义

好了,现在总结一下:因为经过了复杂的神经网络之后,out中每个数值都是由很多输入样本的属性(也就是输入数据)线性或者非线性组合而成的,那么out中的每个数值和输入数据的每个数值都有关联,也就是说【out】中的每个数都可以对【a】中每个数求导,那么我们backward()的参数[k1,k2,k3…kn]的含义就是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

也可以理解成每个out分量对an求导时的权重。

对矩阵自动求导

现在,如果out是一个矩阵呢?

下面的例子也可以理解为:相当于一个神经网络有两个样本,每个样本有两个属性,神经网络有两个输出。

import torch
from torch.autograd import Variable
from torch import nn

a = Variable(torch.FloatTensor([[2,3],[1,2]]),requires_grad=True)
w = Variable( torch.zeros(2,1),requires_grad=True )
out = torch.mm(a,w)
out.backward(torch.FloatTensor([[1.],[1.]]),retain_graph=True)
print("gradients are:{}".format(w.grad.data))

如果前面的例子理解了,那么这个也很好理解,backward输入的参数k是一个2x1的矩阵,2代表的就是样本数量,就是在前面的基础上,再对每个样本进行加权求和。结果是:

浅谈Pytorch中的自动求导函数backward()所需参数的含义

如果有兴趣,也可以拓展一下多个样本的多分类问题,猜一下k的维度应该是【输入样本的个数 * 分类的个数】

好啦,纠结我好久的pytorch自动求导原理算是彻底搞懂啦~~~

以上这篇浅谈Pytorch中的自动求导函数backward()所需参数的含义就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Socket编程入门教程
Jul 11 Python
Python列表append和+的区别浅析
Feb 02 Python
初步剖析C语言编程中的结构体
Jan 16 Python
从头学Python之编写可执行的.py文件
Nov 28 Python
python中闭包Closure函数作为返回值的方法示例
Dec 17 Python
Python数据分析之双色球中蓝红球分析统计示例
Feb 03 Python
python使用pandas处理excel文件转为csv文件的方法示例
Jul 18 Python
python创建子类的方法分析
Nov 28 Python
Python web如何在IIS发布应用过程解析
May 27 Python
快速了解Python开发环境Spyder
Jun 29 Python
python自动化八大定位元素讲解
Jul 09 Python
Python如何让字典保持有序排列
Apr 29 Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 #Python
You might like
window+nginx+php环境配置 附配置搭配说明
2010/12/29 PHP
PHP+Mysql+jQuery实现动态展示信息
2011/10/08 PHP
php写的AES加密解密类分享
2014/06/20 PHP
PHP使用DOM对XML解析处理操作示例
2019/07/04 PHP
ParseInt函数参数设置介绍
2014/01/02 Javascript
NodeJS学习笔记之Http模块
2015/01/13 NodeJs
使用jQuery实现更改默认alert框体
2015/04/13 Javascript
jQuery基于ajax()使用serialize()提交form数据的方法
2015/12/08 Javascript
JS中生成随机数的用法及相关函数
2016/01/09 Javascript
JS中使用mailto实现将用户在网页中输入的内容传递到本地邮件客户端
2016/10/08 Javascript
浅谈js-FCC算法Friendly Date Ranges(详解)
2017/04/10 Javascript
使用JavaScript开发跨平台的桌面应用详解
2017/07/27 Javascript
详解React 的几种条件渲染以及选择
2018/10/23 Javascript
vue2.0结合Element-ui实战案例
2019/03/06 Javascript
前端面试知识点目录一览
2019/04/15 Javascript
JS实现继承的几种常用方式示例
2019/06/22 Javascript
微信小程序文章详情页跳转案例详解
2019/07/09 Javascript
js+h5 canvas实现图片验证码
2020/10/11 Javascript
Python Web框架Pylons中使用MongoDB的例子
2013/12/03 Python
go语言计算两个时间的时间差方法
2015/03/13 Python
Python中property属性实例解析
2018/02/10 Python
使用requests库制作Python爬虫
2018/03/25 Python
Linux下安装python3.6和第三方库的教程详解
2018/11/09 Python
python爬虫把url链接编码成gbk2312格式过程解析
2020/06/08 Python
去除python中的字符串空格的简单方法
2020/12/22 Python
惠普墨西哥官方商店:HP墨西哥
2016/12/01 全球购物
美国环保婴儿用品公司:The Honest Company
2017/11/23 全球购物
6PM官网:折扣鞋、服装及配饰
2018/08/03 全球购物
美国Curacao百货连锁店网站:iCuracao.com
2019/07/20 全球购物
美国最好的葡萄酒网上商店:Wine Library
2019/11/02 全球购物
Java中的异常处理机制的简单原理和应用
2013/04/27 面试题
社团文化节策划书
2014/02/01 职场文书
商业房地产广告语
2014/03/13 职场文书
霸气队列口号
2014/06/18 职场文书
2015年城管执法工作总结
2015/07/23 职场文书
磁贴还没死, 微软Win11可修改注册表找回Win10开始菜单
2021/11/21 数码科技