pytorch .detach() .detach_() 和 .data用于切断反向传播的实现


Posted in Python onDecember 27, 2019

当我们再训练网络的时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播

1   detach()[source]

返回一个新的Variable,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable永远不需要计算其梯度,不具有grad。

即使之后重新将它的requires_grad置为true,它也不会具有梯度grad

这样我们就会继续使用这个新的Variable进行计算,后面当我们进行反向传播时,到该调用detach()的Variable就会停止,不能再继续向前进行传播

源码为:

def detach(self):
    """Returns a new Variable, detached from the current graph.
    Result will never require gradient. If the input is volatile, the output
    will be volatile too.
    .. note::
     Returned Variable uses the same data tensor, as the original one, and
     in-place modifications on either of them will be seen, and may trigger
     errors in correctness checks.
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None

 return result

可见函数进行的操作有:

  • 将grad_fn设置为None
  • 将Variable的requires_grad设置为False

如果输入 volatile=True(即不需要保存记录,当只需要结果而不需要更新参数时这么设置来加快运算速度),那么返回的Variable volatile=True。(volatile已经弃用)

注意:

返回的Variable和原始的Variable公用同一个data tensor。in-place函数修改会在两个Variable上同时体现(因为它们共享data tensor),当要对其调用backward()时可能会导致错误。

举例:

比如正常的例子是:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()

out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.1966, 0.1050, 0.0452])

当使用detach()但是没有进行更改时,并不会影响backward():

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#这时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])

可见c,out之间的区别是c是没有梯度的,out是有梯度的

如果这里使用的是c进行sum()操作并进行backward(),则会报错:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    c.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

如果此时对c进行了更改,这个更改会被autograd追踪,在对out.sum()进行backward()时也会报错,因为此时的值进行backward()得到的梯度是错误的:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时会影响out的值
print(c)
print(out)

#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    out.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

2   .data

如果上面的操作使用的是.data,效果会不同:

这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])

上面的内容实现的原理是:

In-place 正确性检查

所有的Variable都会记录用在他们身上的 in-place operations。如果pytorch检测到variable在一个Function中已经被保存用来backward,但是之后它又被in-place operations修改。当这种情况发生时,在backward的时候,pytorch就会报错。这种机制保证了,如果你用了in-place operations,但是在backward过程中没有报错,那么梯度的计算就是正确的。

⚠️下面结果正确是因为改变的是sum()的结果,中间值a.sigmoid()并没有被影响,所以其对求梯度并没有影响:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum() #但是如果sum写在这里,而不是写在backward()前,得到的结果是正确的
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#没有写在这里
out.backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])

3   detach_()[source]

将一个Variable从创建它的图中分离,并把它设置成叶子variable

其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子variable是x,但是这个时候对m进行了.detach_()操作,其实就是进行了两个操作:

  • 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
  • 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

这么一看其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的variable

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的

但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中合并两个文本文件并按照姓名首字母排序的例子
Apr 25 Python
python对json的相关操作实例详解
Jan 04 Python
使用tensorflow实现AlexNet
Nov 20 Python
python re模块的高级用法详解
Jun 06 Python
Python实现调用另一个路径下py文件中的函数方法总结
Jun 07 Python
python模块导入的细节详解
Dec 10 Python
Django之Mode的外键自关联和引用未定义的Model方法
Dec 15 Python
微信小程序python用户认证的实现
Jul 29 Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 Python
通过Python实现Payload分离免杀过程详解
Jul 13 Python
用python进行视频剪辑
Nov 02 Python
python实现简单的井字棋游戏(gui界面)
Jan 22 Python
python操作gitlab API过程解析
Dec 27 #Python
python  ceiling divide 除法向上取整(或小数向上取整)的实例
Dec 27 #Python
python使用协程实现并发操作的方法详解
Dec 27 #Python
Python调用.NET库的方法步骤
Dec 27 #Python
IronPython连接MySQL的方法步骤
Dec 27 #Python
python基于三阶贝塞尔曲线的数据平滑算法
Dec 27 #Python
python3获取文件中url内容并下载代码实例
Dec 27 #Python
You might like
Yii分页用法实例详解
2014/12/04 PHP
php通过asort()给关联数组按照值排序的方法
2015/03/18 PHP
PHP微信企业号开发之回调模式开启与用法示例
2017/11/25 PHP
IE6/7/8中Option元素未设value时Select将获取空字符串
2011/04/07 Javascript
最新28个很棒的jQuery 教程
2011/05/28 Javascript
基于jquery的文章中所有图片width大小批量设置方法
2013/08/01 Javascript
Jquery 的outerHeight方法使用介绍
2013/09/11 Javascript
第二章之Bootstrap 页面排版样式
2016/04/25 Javascript
全面解析DOM操作和jQuery实现选项移动操作代码分享
2016/06/07 Javascript
返回函数的JavaScript函数
2016/06/14 Javascript
简单封装js的dom查询实例代码
2016/07/08 Javascript
layui前段框架日期控件使用方法详解
2017/05/19 Javascript
利用nodeJs anywhere搭建本地服务器环境的方法
2018/05/12 NodeJs
Vue中使用create-keyframe-animation与动画钩子完成复杂动画
2019/04/09 Javascript
node中IO以及定时器优先级详解
2019/05/10 Javascript
vue props default Array或是Object的正确写法说明
2020/07/30 Javascript
vue 解决无法对未定义的值,空值或基元值设置反应属性报错问题
2020/07/31 Javascript
利用 Chrome Dev Tools 进行页面性能分析的步骤说明(前端性能优化)
2021/02/24 Javascript
Python DataFrame.groupby()聚合函数,分组级运算
2018/09/18 Python
Python使用ctypes调用C/C++的方法
2019/01/29 Python
python3实现在二叉树中找出和为某一值的所有路径(推荐)
2019/12/26 Python
PyTorch实现重写/改写Dataset并载入Dataloader
2020/07/14 Python
Yummie官方网站:塑身衣和衣柜必需品
2019/10/29 全球购物
TCP/IP的分层模型
2013/10/27 面试题
会计电算化专业毕业生推荐信
2013/12/24 职场文书
追悼会上的答谢词
2014/01/10 职场文书
小学生自我评价范文
2014/01/25 职场文书
《手指教学》反思
2014/02/14 职场文书
班级读书活动总结
2014/06/30 职场文书
最美乡村教师观后感
2015/06/11 职场文书
经营场所使用证明
2015/06/19 职场文书
2015年语言文字工作总结
2015/07/23 职场文书
成功的商业计划书这样写才最靠谱
2019/07/12 职场文书
《哪吒之魔童降世》观后感:世上哪有随随便便的成功
2019/11/08 职场文书
python opencv将多个图放在一个窗口的实例详解
2022/02/28 Python
MybatisPlus EntityWrapper如何自定义SQL
2022/03/22 Java/Android