pytorch .detach() .detach_() 和 .data用于切断反向传播的实现


Posted in Python onDecember 27, 2019

当我们再训练网络的时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播

1   detach()[source]

返回一个新的Variable,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable永远不需要计算其梯度,不具有grad。

即使之后重新将它的requires_grad置为true,它也不会具有梯度grad

这样我们就会继续使用这个新的Variable进行计算,后面当我们进行反向传播时,到该调用detach()的Variable就会停止,不能再继续向前进行传播

源码为:

def detach(self):
    """Returns a new Variable, detached from the current graph.
    Result will never require gradient. If the input is volatile, the output
    will be volatile too.
    .. note::
     Returned Variable uses the same data tensor, as the original one, and
     in-place modifications on either of them will be seen, and may trigger
     errors in correctness checks.
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None

 return result

可见函数进行的操作有:

  • 将grad_fn设置为None
  • 将Variable的requires_grad设置为False

如果输入 volatile=True(即不需要保存记录,当只需要结果而不需要更新参数时这么设置来加快运算速度),那么返回的Variable volatile=True。(volatile已经弃用)

注意:

返回的Variable和原始的Variable公用同一个data tensor。in-place函数修改会在两个Variable上同时体现(因为它们共享data tensor),当要对其调用backward()时可能会导致错误。

举例:

比如正常的例子是:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()

out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.1966, 0.1050, 0.0452])

当使用detach()但是没有进行更改时,并不会影响backward():

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#这时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])

可见c,out之间的区别是c是没有梯度的,out是有梯度的

如果这里使用的是c进行sum()操作并进行backward(),则会报错:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    c.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

如果此时对c进行了更改,这个更改会被autograd追踪,在对out.sum()进行backward()时也会报错,因为此时的值进行backward()得到的梯度是错误的:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时会影响out的值
print(c)
print(out)

#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    out.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

2   .data

如果上面的操作使用的是.data,效果会不同:

这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])

上面的内容实现的原理是:

In-place 正确性检查

所有的Variable都会记录用在他们身上的 in-place operations。如果pytorch检测到variable在一个Function中已经被保存用来backward,但是之后它又被in-place operations修改。当这种情况发生时,在backward的时候,pytorch就会报错。这种机制保证了,如果你用了in-place operations,但是在backward过程中没有报错,那么梯度的计算就是正确的。

⚠️下面结果正确是因为改变的是sum()的结果,中间值a.sigmoid()并没有被影响,所以其对求梯度并没有影响:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum() #但是如果sum写在这里,而不是写在backward()前,得到的结果是正确的
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#没有写在这里
out.backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])

3   detach_()[source]

将一个Variable从创建它的图中分离,并把它设置成叶子variable

其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子variable是x,但是这个时候对m进行了.detach_()操作,其实就是进行了两个操作:

  • 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
  • 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

这么一看其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的variable

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的

但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python字符串拼接的几种方法整理
Aug 02 Python
python模块之time模块(实例讲解)
Sep 13 Python
Python实现多条件筛选目标数据功能【测试可用】
Jun 13 Python
Python地图绘制实操详解
Mar 04 Python
selenium+python环境配置教程详解
May 28 Python
详解Pandas之容易让人混淆的行选择和列选择
Jul 10 Python
使用python实现男神女神颜值打分系统(推荐)
Oct 31 Python
Python内置函数property()如何使用
Sep 01 Python
Python面向对象之内置函数相关知识总结
Jun 24 Python
浅析python中特殊文件和特殊函数
Feb 24 Python
python实现对doc、txt、xls等文档的读写操作
Apr 02 Python
PyTorch中permute的使用方法
Apr 26 Python
python操作gitlab API过程解析
Dec 27 #Python
python  ceiling divide 除法向上取整(或小数向上取整)的实例
Dec 27 #Python
python使用协程实现并发操作的方法详解
Dec 27 #Python
Python调用.NET库的方法步骤
Dec 27 #Python
IronPython连接MySQL的方法步骤
Dec 27 #Python
python基于三阶贝塞尔曲线的数据平滑算法
Dec 27 #Python
python3获取文件中url内容并下载代码实例
Dec 27 #Python
You might like
杏林同学录(八)
2006/10/09 PHP
使用apache模块rewrite_module (转)
2007/02/14 PHP
PHP开发过程中常用函数收藏
2009/12/14 PHP
ThinkPHP中order()使用方法详解
2016/04/19 PHP
解决yii2左侧菜单子级无法高亮问题的方法
2016/05/08 PHP
php获取远程图片并下载保存到本地的方法分析
2016/10/08 PHP
PNGHandler-借助JS让PNG图在IE下实现透明(包括背景图)
2007/08/31 Javascript
jquery重新播放css动画所遇问题解决
2013/08/21 Javascript
推荐4个原生javascript常用的函数
2015/01/12 Javascript
基于jquery实现在线选座订座之影院篇
2015/08/24 Javascript
总结JavaScript中布尔操作符||与&amp;&amp;的使用技巧
2015/11/17 Javascript
原生JavaScript实现动态省市县三级联动下拉框菜单实例代码
2016/02/03 Javascript
jQuery实现的浮动层div浏览器居中显示效果
2017/02/03 Javascript
Windows下快速搭建NodeJS本地服务器的步骤
2017/08/09 NodeJs
jQuery实现点击自身以外区域关闭弹出层功能完整示例【改进版】
2018/07/31 jQuery
JS加密插件CryptoJS实现的Base64加密示例
2020/08/16 Javascript
Vue2.0实现简单分页及跳转效果
2019/07/29 Javascript
在vue中封装方法以及多处引用该方法详解
2020/08/14 Javascript
详解js创建对象的几种方式和对象方法
2021/03/01 Javascript
Python中使用strip()方法删除字符串中空格的教程
2015/05/20 Python
Python图算法实例分析
2016/08/13 Python
python 判断网络连通的实现方法
2018/04/22 Python
python五子棋游戏的设计与实现
2019/06/18 Python
python爬虫 模拟登录人人网过程解析
2019/07/31 Python
对Django 转发和重定向的实例详解
2019/08/06 Python
Pytorch 抽取vgg各层并进行定制化处理的方法
2019/08/20 Python
python 消费 kafka 数据教程
2019/12/21 Python
python如何更新包
2020/06/11 Python
python在一个范围内取随机数的简单实例
2020/08/16 Python
使用HTML5的Notification API制作web通知的教程
2015/05/08 HTML / CSS
英国泰坦旅游网站:全球陪同游览,邮轮和铁路旅行
2016/11/29 全球购物
新加坡领先的在线生活方式和杂货购物网站:EAMART
2019/04/02 全球购物
自1926年以来就为冰岛保持温暖:66°North
2020/11/27 全球购物
数学与统计学院学生个人职业生涯规划书
2014/02/10 职场文书
人力资源经理的岗位职责
2014/03/02 职场文书
大学社团计划书
2014/05/01 职场文书