PyTorch中clone()、detach()及相关扩展详解


Posted in Python onDecember 09, 2020

clone() 与 detach() 对比

Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab。如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝,

首先我们来打印出来clone()操作后的数据类型定义变化:

(1). 简单打印类型

import torch

a = torch.tensor(1.0, requires_grad=True)
b = a.clone()
c = a.detach()
a.data *= 3
b += 1

print(a) # tensor(3., requires_grad=True)
print(b)
print(c)

'''
输出结果:
tensor(3., requires_grad=True)
tensor(2., grad_fn=<AddBackward0>)
tensor(3.) # detach()后的值随着a的变化出现变化
'''

grad_fn=<CloneBackward>,表示clone后的返回值是个中间变量,因此支持梯度的回溯。clone操作在一定程度上可以视为是一个identity-mapping函数。

detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变。

注意: 在pytorch中我们不要直接使用id是否相等来判断tensor是否共享内存,这只是充分条件,因为也许底层共享数据内存,但是仍然是新的tensor,比如detach(),如果我们直接打印id会出现以下情况。

import torch as t
a = t.tensor([1.0,2.0], requires_grad=True)
b = a.detach()
#c[:] = a.detach()
print(id(a))
print(id(b))
#140568935450520
140570337203616

显然直接打印出来的id不等,我们可以通过简单的赋值后观察数据变化进行判断。

(2). clone()的梯度回传

detach()函数可以返回一个完全相同的tensor,与旧的tensor共享内存,脱离计算图,不会牵扯梯度计算。

而clone充当中间变量,会将梯度传给源张量进行叠加,但是本身不保存其grad,即值为None

import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
y = a**2
z = a ** 2+a_ * 3
y.backward()
print(a.grad) # 2
z.backward()
print(a_.grad) # None. 中间variable,无grad
print(a.grad) 
'''
输出:
tensor(2.) 
None
tensor(7.) # 2*2+3=7
'''

使用torch.clone()获得的新tensor和原来的数据不再共享内存,但仍保留在计算图中,clone操作在不共享数据内存的同时支持梯度梯度传递与叠加,所以常用在神经网络中某个单元需要重复使用的场景下。

通常如果原tensor的requires_grad=True,则:

  • clone()操作后的tensor requires_grad=True
  • detach()操作后的tensor requires_grad=False。
import torch
torch.manual_seed(0)

x= torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone() 
detach_x = x.detach()
clone_detach_x = x.clone().detach() 

f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
输出结果如下:
tensor([-0.0053, 0.3793])
True
None
False
False
'''

另一个比较特殊的是当源张量的 require_grad=False,clone后的张量 require_grad=True,此时不存在张量回传现象,可以得到clone后的张量求导。

如下:

import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_() #require_grad=True
y = a_ ** 2
y.backward()
print(a.grad) # None
print(a_.grad) 
'''
输出:
None
tensor(2.)
'''

了解了两者的区别后我们常与其他函数进行搭配使用,实现数据拷贝后的其他需要。

比如我们经常使用view()函数对tensor进行reshape操作。返回的新Tensor与源Tensor可能有不同的size,但是是共享data的,即其中的一个发生变化,另外一个也会跟着改变。

需要注意的是view返回的Tensor与源Tensor是共享data的,但是依然是一个新的Tensor(因为Tensor除了包含data外还有一些其他属性),两者id(内存地址)并不一致。

x = torch.rand(2, 2)
y = x.view(4)
x += 1
print(x)
print(y) # 也加了1

view() 仅仅是改变了对这个张量的观察角度,内部数据并未改变。这时候想返回一个真正新的副本(即不共享data内存)该怎么办呢?Pytorch还提供了一个reshape()可以改变形状,但是此函数并不能保证返回的是其拷贝,所以不推荐使用。推荐先用clone创造一个副本然后再使用view。参考此处

x = torch.rand(2, 2)
x_cp = x.clone().view(4)
x += 1
print(id(x))
print(id(x_cp))
print(x)
print(x_cp)
'''
140568935036464
140568935035816
tensor([[0.4963, 0.7682],
 [0.1320, 0.3074]])
tensor([[1.4963, 1.7682, 1.1320, 1.3074]]) 
'''

另外使用clone()会被记录在计算图中,即梯度回传到副本时也会传到源Tensor。在上一篇中有总结。

总结:

  • torch.detach() — 新的tensor会脱离计算图,不会牵扯梯度计算
  • torch.clone() — 新的tensor充当中间变量,会保留在计算图中,参与梯度计算(回传叠加),但是一般不会保留自身梯度。
    原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面两者中执行都会引发错误或者警告。
  • 共享数据内存是底层设计,并不能简单的通过直接打印tensor的id地址进行判断,需要在进行赋值或运算操作后打印比较数据的变化进行判断。
  • 复制操作可以根据实际需要进行结合使用。

引用官方文档的话:如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。另外尽量避免in-place的使用。

像y = x + y这样的运算会新开内存,然后将y指向新内存。我们可以使用Python自带的id函数进行验证:如果两个实例的ID相同,则它们所对应的内存地址相同。

到此这篇关于PyTorch中clone()、detach()及相关扩展详解的文章就介绍到这了,更多相关PyTorch中clone()、detach()及相关扩展内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
布同 统计英文单词的个数的python代码
Mar 13 Python
在Python中利用Into包整洁地进行数据迁移的教程
Mar 30 Python
python实现在sqlite动态创建表的方法
May 08 Python
pymongo为mongodb数据库添加索引的方法
May 11 Python
简单讲解Python中的闭包
Aug 11 Python
python负载均衡的简单实现方法
Feb 04 Python
Pandas:Series和DataFrame删除指定轴上数据的方法
Nov 10 Python
Python unittest单元测试框架实现参数化
Apr 29 Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 Python
python 利用百度API识别图片文字(多线程版)
Dec 14 Python
地图可视化神器kepler.gl python接口的使用方法
Dec 22 Python
在Python中如何使用yield
Jun 07 Python
python调用jenkinsAPI构建jenkins,并传递参数的示例
Dec 09 #Python
python excel多行合并的方法
Dec 09 #Python
PyTorch中的拷贝与就地操作详解
Dec 09 #Python
python 调用Google翻译接口的方法
Dec 09 #Python
浅析Python 中的 WSGI 接口和 WSGI 服务的运行
Dec 09 #Python
python dir函数快速掌握用法技巧
Dec 09 #Python
5 分钟读懂Python 中的 Hook 钩子函数
Dec 09 #Python
You might like
apache2.2.4+mysql5.0.77+php5.2.8安装精简
2009/04/29 PHP
Yii安装与使用Excel扩展的方法
2016/07/13 PHP
PHP-CGI远程代码执行漏洞分析与防范
2017/05/07 PHP
PHP中用Trait封装单例模式的实现
2019/12/18 PHP
一个不错的用JavaScript实现的UBB编码函数
2007/03/09 Javascript
基于jquery的direction图片渐变动画效果
2010/05/24 Javascript
StringTemplate遇见jQuery冲突的解决方法
2011/09/22 Javascript
JavaScript DOM节点添加示例
2014/07/16 Javascript
使用jQuery判断浏览器滚动条位置的方法
2016/05/30 Javascript
Javascript 获取鼠标当前的位置实现方法
2016/10/27 Javascript
echarts鼠标覆盖高亮显示节点及关系名称详解
2018/03/17 Javascript
Vue-cropper 图片裁剪的基本原理及思路讲解
2018/04/17 Javascript
微信小程序项目总结之点赞 删除列表 分享功能
2018/06/25 Javascript
vue登录以及权限验证相关的实现
2019/10/25 Javascript
JS实现按比例缩小图片宽高
2020/08/24 Javascript
python中快速进行多个字符替换的方法小结
2016/12/15 Python
Python实现快速排序的方法详解
2019/10/25 Python
利用Python校准本地时间的方法教程
2019/10/31 Python
Python3如何对urllib和urllib2进行重构
2019/11/25 Python
python3 动态模块导入与全局变量使用实例
2019/12/22 Python
Python JSON编解码方式原理详解
2020/01/20 Python
手把手教你如何用Pycharm2020.1.1配置远程连接的详细步骤
2020/08/07 Python
关于iframe跨域使用postMessage的实现
2019/10/29 HTML / CSS
阿迪达斯丹麦官网:adidas丹麦
2016/10/01 全球购物
Lookfantastic西班牙官网:英国知名美妆购物网站
2018/06/13 全球购物
Converse匡威法国官网:美国著名帆布鞋品牌
2018/12/05 全球购物
简历自我评价怎么写呢?
2014/01/06 职场文书
教师个人剖析材料
2014/02/05 职场文书
机械制造专业毕业生求职信
2014/03/02 职场文书
小小商店教学反思
2014/04/27 职场文书
演讲稿祖国在我心中
2014/05/04 职场文书
法制宣传教育方案
2014/05/09 职场文书
2014国庆黄金周超市促销活动方案
2014/09/21 职场文书
2015年妇幼卫生工作总结
2015/05/23 职场文书
Pycharm远程调试和MySQL数据库授权问题
2022/03/18 MySQL
分享Python异步爬取知乎热榜
2022/04/12 Python