pytorch的梯度计算以及backward方法详解


Posted in Python onJanuary 10, 2020

基础知识

tensors:

tensor在pytorch里面是一个n维数组。我们可以通过指定参数reuqires_grad=True来建立一个反向传播图,从而能够计算梯度。在pytorch中一般叫做dynamic computation graph(DCG)——即动态计算图。

import torch
import numpy as np

# 方式一
x = torch.randn(2,2, requires_grad=True)

# 方式二
x = torch.autograd.Variable(torch.Tensor([2,3]), requires_grad=True)

#方式三
x = torch.tensor([2,3], requires_grad=True, dtype=torch.float64)

# 方式四
x = np.array([1,2,3] ,dtype=np.float64)
x = torch.from_numpy(x)
x.requires_grad = True
# 或者 x.requires_grad_(True)

note1:在pytorch中,只有浮点类型的数才有梯度,故在方法四中指定np数组的类型为float类型。为什么torch.Tensor中不需要呢,可以通过以下代码验证

import torch
import numpy as np

a = torch.Tensor([2,3])
print(a.dtype) # torch.floaat32

b = torch.tensor([2,3])
print(b.dtype) # torch.int64

 c = np.array(2,3)
 print(c.dtype) # int64

note2pytorch中tensor与Tensor的区别是什么?这两个看起来如此相似。

首先,torch.Tensor是一个类,所有的tensor都是Tensor的一个实例;而torch.tensor是一个函数。这也说明了为什么使用torch.Tensor()没有问题而torch.tensor()却有问题。

其次,torch.tensor主要是将一个data封装成tensor,并且可以指定requires_grad。

torch.tensor(data,dtype=None,device=None,requires_grad=False) - > Tensor

最后,我们更多地使用torch.tensor,我们可以通过使用torch.tensor(())来达到与torch.Tensor()同样的效果。

具体可参考torch.tensor与torch.Tensor的区别

Dynamic Computational graph

我们来看一个计算图

pytorch的梯度计算以及backward方法详解

我们 来看一个计算图 解释一下各个属性的含义,

data: 变量中存储的值,如x中存储着1,y中存储着2,z中存储着3

requires_grad:该变量有两个值,True 或者 False,如果为True,则加入到反向传播图中参与计算。

grad:该属性存储着相关的梯度值。当requires_grad为False时,该属性为None。即使requires_grad为True,也必须在调用其他节点的backward()之后,该变量的grad才会保存相关的梯度值。否则为None

grad_fn:表示用于计算梯度的函数。

is_leaf:为True或者False,表示该节点是否为叶子节点。

当调用backward函数时,只有requires_grad为true以及is_leaf为true的节点才会被计算梯度,即grad属性才会被赋予值。

梯度计算

examples

运算结果变量的requires_grad取决于输入变量。例如:当变量z的requires_grad属性为True时,为了求得z的梯度,那么变量b的requires_grad就必须为true了,而变量x,y,a的requires_grad属性都为False。

将事先创建的变量,如x、y、z称为创建变量;像a、b这样由其他变量运算得到的称为结果变量。

from torch.autograd import Variable

x = Variable(torch.randn(2,2))
y = Variable(torch.randn(2,2))
z = Variable(torch.randn(2,2), requires_grad=True)


a = x+y
b = a+z

print(x.requires_grad, y.requires_grad, z.requires_grad) # False, False, True
print(a.requires_grad, b.requires_grad) # False, True

print(x.requires_grad) # True
print(a.requires_grad) # True

调用backward()计算梯度

import torch as t
from torch.autograd import Variable as v

a = v(t.FloatTensor([2, 3]), requires_grad=True) 
b = a + 3
c = b * b * 3
out = c.mean()
out.backward(retain_graph=True) # 这里可以不带参数,默认值为‘1',由于下面我们还要求导,故加上retain_graph=True选项

print(a.grad) # tensor([15., 18.])

backward中的gradient参数使用

a. 最后的结果变量为标量(scalar)

如第二个例子,通过调用out.backward()实现对a的求导,这里默认调用了out.backward(gradient=None)或者指定为out.backward(gradient=torch.Tensor([1.0])

b. 最后的结果变量为向量(vector)

import torch
from torch.autograd import Variable as V

m = V(torch.FloatTensor([2, 3]), requires_grad=True) # 注意这里有两层括号,非标量
n = V(torch.zeros(2))
n[0] = m[0] ** 2
n[1] = m[1] ** 3
n.backward(gradient=torch.Tensor([1,1]), retain_graph=True)
print(m.grad)

结果为:

tensor([ 4., 27.])

如果使用n.backward()的话,那么就会报如下的错:RuntimeError: grad can be implicitly created only for scalar outputs

注意:这里的gradient的维度必须与n的维度相同。其中的原理如下:

在执行z.backward(gradient)的时候,如果z不是一个标量,那么先构造一个标量的值:L = torch.sum(z*gradient),再计算关于L对各个leaf Variable的梯度。

pytorch的梯度计算以及backward方法详解

以上这篇pytorch的梯度计算以及backward方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
一篇不错的Python入门教程
Feb 08 Python
python 中文字符串的处理实现代码
Oct 25 Python
Python牛刀小试密码爆破
Feb 03 Python
使用django-suit为django 1.7 admin后台添加模板
Nov 18 Python
C#返回当前系统所有可用驱动器符号的方法
Apr 18 Python
Python中的字典与成员运算符初步探究
Oct 13 Python
基于Python 的进程管理工具supervisor使用指南
Sep 18 Python
Python中摘要算法MD5,SHA1简介及应用实例代码
Jan 09 Python
python绘制热力图heatmap
Mar 23 Python
基于wxPython的GUI实现输入对话框(1)
Feb 27 Python
python3.x实现base64加密和解密
Mar 28 Python
详解Python sys.argv使用方法
May 10 Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
python-OpenCV 实现将数组转换成灰度图和彩图
Jan 09 #Python
Python 实现将数组/矩阵转换成Image类
Jan 09 #Python
python 实现将Numpy数组保存为图像
Jan 09 #Python
Python+OpenCV实现将图像转换为二进制格式
Jan 09 #Python
如何使用Python破解ZIP或RAR压缩文件密码
Jan 09 #Python
python读取raw binary图片并提取统计信息的实例
Jan 09 #Python
You might like
海河写的 Discuz论坛帖子调用js的php代码
2007/08/23 PHP
PHP has encountered an Access Violation at 7C94BD02解决方法
2009/08/24 PHP
PHP图片处理之图片旋转和图片翻转实例
2014/11/19 PHP
thinkPHP数据查询常用方法总结【select,find,getField,query】
2017/03/15 PHP
PHP扩展Swoole实现实时异步任务队列示例
2019/04/13 PHP
jQuery asp.net 用json格式返回自定义对象
2010/04/07 Javascript
基于jquery的$.ajax async使用
2011/10/19 Javascript
js取整数、取余数的方法
2014/05/11 Javascript
JS往数组中添加项性能分析
2015/02/25 Javascript
基于jQuery 实现bootstrapValidator下的全局验证
2015/12/07 Javascript
js实现将选中内容分享到新浪或腾讯微博
2015/12/16 Javascript
一个字符串中出现次数最多的字符 统计这个次数【实现代码】
2016/04/29 Javascript
使用jquery获取url以及jquery获取url参数的实现方法
2016/05/25 Javascript
基于iscroll.js实现下拉刷新和上拉加载效果
2016/11/28 Javascript
Javascript中八种遍历方法的执行速度深度对比
2017/04/25 Javascript
jQuery实现导航栏头部菜单项点击后变换颜色的方法
2017/07/19 jQuery
vue.js路由跳转详解
2017/08/28 Javascript
zTree节点文字过多的处理方法
2017/11/24 Javascript
angularJs自定义过滤器实现手机号信息隐藏的方法
2018/10/08 Javascript
详解Vue前端对axios的封装和使用
2019/04/01 Javascript
前端面试知识点目录一览
2019/04/15 Javascript
微信小程序 弹窗输入组件的实现解析
2019/08/12 Javascript
JavaScript实现文件下载并重命名代码实例
2019/12/12 Javascript
如何在postman测试用例中实现断言过程解析
2020/07/09 Javascript
python实现将汉字转换成汉语拼音的库
2015/05/05 Python
Python常用时间操作总结【取得当前时间、时间函数、应用等】
2017/05/11 Python
Python面向对象程序设计OOP深入分析【构造函数,组合类,工具类等】
2019/01/05 Python
Python 实现中值滤波、均值滤波的方法
2019/01/09 Python
python groupby 函数 as_index详解
2019/12/16 Python
ReVive利维肤美国官网:RéVive Skincare
2018/04/18 全球购物
橄榄树药房:OLIVEDA
2019/09/01 全球购物
武汉高蓝德国际.net机试
2016/06/24 面试题
大学三年的自我评价
2013/12/25 职场文书
《李广射虎》教学反思
2014/04/27 职场文书
晚会闭幕词
2015/01/28 职场文书
Jupyter Notebook内使用argparse报错的解决方案
2021/06/03 Python