PyTorch中clone()、detach()及相关扩展详解


Posted in Python onDecember 09, 2020

clone() 与 detach() 对比

Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab。如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝,

首先我们来打印出来clone()操作后的数据类型定义变化:

(1). 简单打印类型

import torch

a = torch.tensor(1.0, requires_grad=True)
b = a.clone()
c = a.detach()
a.data *= 3
b += 1

print(a) # tensor(3., requires_grad=True)
print(b)
print(c)

'''
输出结果:
tensor(3., requires_grad=True)
tensor(2., grad_fn=<AddBackward0>)
tensor(3.) # detach()后的值随着a的变化出现变化
'''

grad_fn=<CloneBackward>,表示clone后的返回值是个中间变量,因此支持梯度的回溯。clone操作在一定程度上可以视为是一个identity-mapping函数。

detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变。

注意: 在pytorch中我们不要直接使用id是否相等来判断tensor是否共享内存,这只是充分条件,因为也许底层共享数据内存,但是仍然是新的tensor,比如detach(),如果我们直接打印id会出现以下情况。

import torch as t
a = t.tensor([1.0,2.0], requires_grad=True)
b = a.detach()
#c[:] = a.detach()
print(id(a))
print(id(b))
#140568935450520
140570337203616

显然直接打印出来的id不等,我们可以通过简单的赋值后观察数据变化进行判断。

(2). clone()的梯度回传

detach()函数可以返回一个完全相同的tensor,与旧的tensor共享内存,脱离计算图,不会牵扯梯度计算。

而clone充当中间变量,会将梯度传给源张量进行叠加,但是本身不保存其grad,即值为None

import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
y = a**2
z = a ** 2+a_ * 3
y.backward()
print(a.grad) # 2
z.backward()
print(a_.grad) # None. 中间variable,无grad
print(a.grad) 
'''
输出:
tensor(2.) 
None
tensor(7.) # 2*2+3=7
'''

使用torch.clone()获得的新tensor和原来的数据不再共享内存,但仍保留在计算图中,clone操作在不共享数据内存的同时支持梯度梯度传递与叠加,所以常用在神经网络中某个单元需要重复使用的场景下。

通常如果原tensor的requires_grad=True,则:

  • clone()操作后的tensor requires_grad=True
  • detach()操作后的tensor requires_grad=False。
import torch
torch.manual_seed(0)

x= torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone() 
detach_x = x.detach()
clone_detach_x = x.clone().detach() 

f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
输出结果如下:
tensor([-0.0053, 0.3793])
True
None
False
False
'''

另一个比较特殊的是当源张量的 require_grad=False,clone后的张量 require_grad=True,此时不存在张量回传现象,可以得到clone后的张量求导。

如下:

import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_() #require_grad=True
y = a_ ** 2
y.backward()
print(a.grad) # None
print(a_.grad) 
'''
输出:
None
tensor(2.)
'''

了解了两者的区别后我们常与其他函数进行搭配使用,实现数据拷贝后的其他需要。

比如我们经常使用view()函数对tensor进行reshape操作。返回的新Tensor与源Tensor可能有不同的size,但是是共享data的,即其中的一个发生变化,另外一个也会跟着改变。

需要注意的是view返回的Tensor与源Tensor是共享data的,但是依然是一个新的Tensor(因为Tensor除了包含data外还有一些其他属性),两者id(内存地址)并不一致。

x = torch.rand(2, 2)
y = x.view(4)
x += 1
print(x)
print(y) # 也加了1

view() 仅仅是改变了对这个张量的观察角度,内部数据并未改变。这时候想返回一个真正新的副本(即不共享data内存)该怎么办呢?Pytorch还提供了一个reshape()可以改变形状,但是此函数并不能保证返回的是其拷贝,所以不推荐使用。推荐先用clone创造一个副本然后再使用view。参考此处

x = torch.rand(2, 2)
x_cp = x.clone().view(4)
x += 1
print(id(x))
print(id(x_cp))
print(x)
print(x_cp)
'''
140568935036464
140568935035816
tensor([[0.4963, 0.7682],
 [0.1320, 0.3074]])
tensor([[1.4963, 1.7682, 1.1320, 1.3074]]) 
'''

另外使用clone()会被记录在计算图中,即梯度回传到副本时也会传到源Tensor。在上一篇中有总结。

总结:

  • torch.detach() — 新的tensor会脱离计算图,不会牵扯梯度计算
  • torch.clone() — 新的tensor充当中间变量,会保留在计算图中,参与梯度计算(回传叠加),但是一般不会保留自身梯度。
    原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面两者中执行都会引发错误或者警告。
  • 共享数据内存是底层设计,并不能简单的通过直接打印tensor的id地址进行判断,需要在进行赋值或运算操作后打印比较数据的变化进行判断。
  • 复制操作可以根据实际需要进行结合使用。

引用官方文档的话:如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。另外尽量避免in-place的使用。

像y = x + y这样的运算会新开内存,然后将y指向新内存。我们可以使用Python自带的id函数进行验证:如果两个实例的ID相同,则它们所对应的内存地址相同。

到此这篇关于PyTorch中clone()、detach()及相关扩展详解的文章就介绍到这了,更多相关PyTorch中clone()、detach()及相关扩展内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中处理字符串之isalpha()方法的使用
May 18 Python
使用Python实现BT种子和磁力链接的相互转换
Nov 09 Python
python实现爬虫统计学校BBS男女比例(一)
Dec 31 Python
python绘制双柱形图代码实例
Dec 14 Python
今天 平安夜 Python 送你一顶圣诞帽 @微信官方
Dec 25 Python
浅谈Python中的作用域规则和闭包
Mar 20 Python
python+opencv实现阈值分割
Dec 26 Python
python读取csv和txt数据转换成向量的实例
Feb 12 Python
python实现列表中最大最小值输出的示例
Jul 09 Python
Django的models模型的具体使用
Jul 15 Python
Python 使用 Pillow 模块给图片添加文字水印的方法
Aug 30 Python
python 获取当前目录下的文件目录和文件名实例代码详解
Mar 10 Python
python调用jenkinsAPI构建jenkins,并传递参数的示例
Dec 09 #Python
python excel多行合并的方法
Dec 09 #Python
PyTorch中的拷贝与就地操作详解
Dec 09 #Python
python 调用Google翻译接口的方法
Dec 09 #Python
浅析Python 中的 WSGI 接口和 WSGI 服务的运行
Dec 09 #Python
python dir函数快速掌握用法技巧
Dec 09 #Python
5 分钟读懂Python 中的 Hook 钩子函数
Dec 09 #Python
You might like
php查找任何页面上的所有链接的方法
2013/12/03 PHP
php+mysqli使用预处理技术进行数据库查询的方法
2015/01/28 PHP
PHP框架Laravel插件Pagination实现自定义分页
2020/04/22 PHP
laravel5 Eloquent 实现事务方式
2019/10/21 PHP
js 实现打印网页中定义的部分内容的代码
2010/04/01 Javascript
利用js实现选项卡的特别效果的实例
2013/03/03 Javascript
jQuery .attr()和.removeAttr()方法操作元素属性示例
2013/07/16 Javascript
jQuery插件Elastislide实现响应式的焦点图无缝滚动切换特效
2015/04/12 Javascript
JavaScript中instanceof运算符的使用示例
2016/06/08 Javascript
Angularjs在初始化未完毕时出现闪烁问题的解决方法分析
2016/08/05 Javascript
原生JavaScript制作计算器
2016/10/16 Javascript
微信小程序 swiper组件详解及实例代码
2016/10/25 Javascript
jquery实时获取时间的简单实例
2017/01/26 Javascript
简单实现js无缝滚动效果
2017/02/05 Javascript
js 数字、字符串、布尔值的转换方法(必看)
2017/04/07 Javascript
vue.js源代码core scedule.js学习笔记
2017/07/03 Javascript
详解如何在vue项目中使用lodop打印插件
2018/09/27 Javascript
JavaScript实现与使用发布/订阅模式详解
2019/01/19 Javascript
layui清空,重置表单数据的实例
2019/09/12 Javascript
Vue实现滑动拼图验证码功能
2019/09/15 Javascript
[01:06:18]DOTA2-DPC中国联赛 正赛 Phoenix vs Dynasty BO3 第二场 1月26日
2021/03/11 DOTA
Python正则表达式教程之三:贪婪/非贪婪特性
2017/03/02 Python
快速入门python学习笔记
2017/12/06 Python
python实现整数的二进制循环移位
2019/03/08 Python
python Manager 之dict KeyError问题的解决
2019/12/21 Python
Python3.6 中的pyinstaller安装和使用教程
2020/03/16 Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
2020/05/22 Python
Python configparser模块封装及构造配置文件
2020/08/07 Python
python 删除系统中的文件(按时间,大小,扩展名)
2020/11/19 Python
python多线程爬取西刺代理的示例代码
2021/01/30 Python
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
2021/03/04 Python
大学生求职自我评价
2014/01/16 职场文书
预备党员表决心书
2014/03/11 职场文书
2014年寒假社会实践活动心得体会
2014/04/07 职场文书
开学第一天的感想
2015/08/10 职场文书
Android开发手册TextInputLayout样式使用示例
2022/06/10 Java/Android