pytorch .detach() .detach_() 和 .data用于切断反向传播的实现


Posted in Python onDecember 27, 2019

当我们再训练网络的时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播

1   detach()[source]

返回一个新的Variable,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable永远不需要计算其梯度,不具有grad。

即使之后重新将它的requires_grad置为true,它也不会具有梯度grad

这样我们就会继续使用这个新的Variable进行计算,后面当我们进行反向传播时,到该调用detach()的Variable就会停止,不能再继续向前进行传播

源码为:

def detach(self):
    """Returns a new Variable, detached from the current graph.
    Result will never require gradient. If the input is volatile, the output
    will be volatile too.
    .. note::
     Returned Variable uses the same data tensor, as the original one, and
     in-place modifications on either of them will be seen, and may trigger
     errors in correctness checks.
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None

 return result

可见函数进行的操作有:

  • 将grad_fn设置为None
  • 将Variable的requires_grad设置为False

如果输入 volatile=True(即不需要保存记录,当只需要结果而不需要更新参数时这么设置来加快运算速度),那么返回的Variable volatile=True。(volatile已经弃用)

注意:

返回的Variable和原始的Variable公用同一个data tensor。in-place函数修改会在两个Variable上同时体现(因为它们共享data tensor),当要对其调用backward()时可能会导致错误。

举例:

比如正常的例子是:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()

out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.1966, 0.1050, 0.0452])

当使用detach()但是没有进行更改时,并不会影响backward():

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#这时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])

可见c,out之间的区别是c是没有梯度的,out是有梯度的

如果这里使用的是c进行sum()操作并进行backward(),则会报错:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)

#使用新生成的Variable进行反向传播
c.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    c.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

如果此时对c进行了更改,这个更改会被autograd追踪,在对out.sum()进行backward()时也会报错,因为此时的值进行backward()得到的梯度是错误的:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时会影响out的值
print(c)
print(out)

#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    out.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

2   .data

如果上面的操作使用的是.data,效果会不同:

这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值
out.sum().backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])

上面的内容实现的原理是:

In-place 正确性检查

所有的Variable都会记录用在他们身上的 in-place operations。如果pytorch检测到variable在一个Function中已经被保存用来backward,但是之后它又被in-place operations修改。当这种情况发生时,在backward的时候,pytorch就会报错。这种机制保证了,如果你用了in-place operations,但是在backward过程中没有报错,那么梯度的计算就是正确的。

⚠️下面结果正确是因为改变的是sum()的结果,中间值a.sigmoid()并没有被影响,所以其对求梯度并没有影响:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum() #但是如果sum写在这里,而不是写在backward()前,得到的结果是正确的
print(out)


c = out.data
print(c)
c.zero_() #使用in place函数对其进行修改

#会发现c的修改同时也会影响out的值
print(c)
print(out)

#没有写在这里
out.backward()
print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])

3   detach_()[source]

将一个Variable从创建它的图中分离,并把它设置成叶子variable

其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子variable是x,但是这个时候对m进行了.detach_()操作,其实就是进行了两个操作:

  • 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
  • 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

这么一看其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的variable

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的

但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理文本文件并生成指定格式的文件
Jul 31 Python
用Python进行TCP网络编程的教程
Apr 29 Python
Python3.2中Print函数用法实例详解
May 19 Python
Python实现将MySQL数据库表中的数据导出生成csv格式文件的方法
Jan 11 Python
Python3.5.3下配置opencv3.2.0的操作方法
Apr 02 Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 Python
Django ORM 查询表中某列字段值的方法
Apr 30 Python
Python把图片转化为pdf代码实例
Jul 28 Python
Python如何创建装饰器时保留函数元信息
Aug 07 Python
Python环境配置实现pip加速过程解析
Nov 27 Python
Python爬虫获取op.gg英雄联盟英雄对位胜率的源码
Jan 29 Python
解决hive中导入text文件遇到的坑
Apr 07 Python
python操作gitlab API过程解析
Dec 27 #Python
python  ceiling divide 除法向上取整(或小数向上取整)的实例
Dec 27 #Python
python使用协程实现并发操作的方法详解
Dec 27 #Python
Python调用.NET库的方法步骤
Dec 27 #Python
IronPython连接MySQL的方法步骤
Dec 27 #Python
python基于三阶贝塞尔曲线的数据平滑算法
Dec 27 #Python
python3获取文件中url内容并下载代码实例
Dec 27 #Python
You might like
模拟SQLSERVER的两个函数:dateadd(),datediff()
2006/10/09 PHP
精通php的十大要点(上)
2009/02/04 PHP
php 获取select下拉列表框的值
2010/05/08 PHP
php将金额数字转化为中文大写
2015/07/09 PHP
PHP简单实现防止SQL注入的方法
2018/03/13 PHP
一个简单的jQuery插件制作 学习过程及实例
2010/04/25 Javascript
JQuery学习笔记 nt-child的使用
2011/01/17 Javascript
原生Js页面滚动延迟加载图片实现原理及过程
2013/06/24 Javascript
jQuery事件绑定和委托实例
2014/11/25 Javascript
JavaScript中的object转换成number或string规则介绍
2014/12/31 Javascript
js打造数组转json函数
2015/01/14 Javascript
jQuery实现跨域
2015/02/03 Javascript
js插件设置innerHTML时在IE8下提示“未知运行时错误”解决方法
2015/04/25 Javascript
动态加载js文件简单示例
2016/04/21 Javascript
jquery实现数字输入框
2017/02/22 Javascript
jQuery的三种bind/One/Live/On事件绑定使用方法
2017/02/23 Javascript
JS实现的走迷宫小游戏完整实例
2017/07/19 Javascript
Angular 数据请求的实现方法
2018/05/07 Javascript
详解基于webpack&amp;gettext的前端多语言方案
2019/01/29 Javascript
详解js 创建对象的几种方法
2019/03/08 Javascript
解决vue-router 二级导航默认选中某一选项的问题
2019/11/01 Javascript
[02:40]DOTA2英雄基础教程 巨牙海民
2013/12/23 DOTA
使用Python的判断语句模拟三目运算
2015/04/24 Python
详解用Python处理HTML转义字符的5种方式
2017/12/27 Python
快速了解Python中的装饰器
2018/01/11 Python
Python自然语言处理 NLTK 库用法入门教程【经典】
2018/06/26 Python
Python中求对数方法总结
2020/03/10 Python
Python+unittest+DDT实现数据驱动测试
2020/11/30 Python
中国网上药店领导者:1药网
2017/02/16 全球购物
Omio意大利:全欧洲低价大巴、火车和航班搜索和比价
2017/12/02 全球购物
好的自荐信包括什么内容
2013/11/07 职场文书
青春无悔演讲稿
2014/05/08 职场文书
社区服务活动感想
2015/08/11 职场文书
学生会副主席竞选稿
2015/11/19 职场文书
私人贷款担保书该怎么写呢?
2019/07/02 职场文书
vue项目配置sass及引入外部scss文件
2022/04/14 Vue.js