PyTorch中的拷贝与就地操作详解


Posted in Python onDecember 09, 2020

前言

PyTroch中我们经常使用到Numpy进行数据的处理,然后再转为Tensor,但是关系到数据的更改时我们要注意方法是否是共享地址,这关系到整个网络的更新。本篇就In-palce操作,拷贝操作中的注意点进行总结。

In-place操作

pytorch中原地操作的后缀为_,如.add_()或.scatter_(),就地操作是直接更改给定Tensor的内容而不进行复制的操作,即不会为变量分配新的内存。Python操作类似+=或*=也是就地操作。(我加了我自己~)

为什么in-place操作可以在处理高维数据时可以帮助减少内存使用呢,下面使用一个例子进行说明,定义以下简单函数来测量PyTorch的异位ReLU(out-of-place)和就地ReLU(in-place)分配的内存:

import torch # import main library
import torch.nn as nn # import modules like nn.ReLU()
import torch.nn.functional as F # import torch functions like F.relu() and F.relu_()

def get_memory_allocated(device, inplace = False):
 '''
 Function measures allocated memory before and after the ReLU function call.
 INPUT:
 - device: gpu device to run the operation
 - inplace: True - to run ReLU in-place, False - for normal ReLU call
 '''
 
 # Create a large tensor
 t = torch.randn(10000, 10000, device=device)
 
 # Measure allocated memory
 torch.cuda.synchronize()
 start_max_memory = torch.cuda.max_memory_allocated() / 1024**2
 start_memory = torch.cuda.memory_allocated() / 1024**2
 
 # Call in-place or normal ReLU
 if inplace:
 F.relu_(t)
 else:
 output = F.relu(t)
 
 # Measure allocated memory after the call
 torch.cuda.synchronize()
 end_max_memory = torch.cuda.max_memory_allocated() / 1024**2
 end_memory = torch.cuda.memory_allocated() / 1024**2
 
 # Return amount of memory allocated for ReLU call
 return end_memory - start_memory, end_max_memory - start_max_memory
 # setup the device
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
#开始测试
# Call the function to measure the allocated memory for the out-of-place ReLU
memory_allocated, max_memory_allocated = get_memory_allocated(device, inplace = False)
print('Allocated memory: {}'.format(memory_allocated))
print('Allocated max memory: {}'.format(max_memory_allocated))
'''
Allocated memory: 382.0
Allocated max memory: 382.0
'''
#Then call the in-place ReLU as follows:
memory_allocated_inplace, max_memory_allocated_inplace = get_memory_allocated(device, inplace = True)
print('Allocated memory: {}'.format(memory_allocated_inplace))
print('Allocated max memory: {}'.format(max_memory_allocated_inplace))
'''
Allocated memory: 0.0
Allocated max memory: 0.0
'''

看起来,使用就地操作可以帮助我们节省一些GPU内存。但是,在使用就地操作时应该格外谨慎。

就地操作的主要缺点主要原因有2点,官方文档:

1.可能会覆盖计算梯度所需的值,这意味着破坏了模型的训练过程。

2.每个就地操作实际上都需要实现来重写计算图。异地操作Out-of-place分配新对象并保留对旧图的引用,而就地操作则需要更改表示此操作的函数的所有输入的创建者。

在Autograd中支持就地操作很困难,并且在大多数情况下不鼓励使用。Autograd积极的缓冲区释放和重用使其非常高效,就地操作实际上降低内存使用量的情况很少。除非在沉重的内存压力下运行,否则可能永远不需要使用它们。

总结:Autograd很香了,就地操作要慎用。

拷贝方法

浅拷贝方法: 共享 data 的内存地址,数据会同步变化

* a.numpy() # Tensor—>Numpy array

* view() #改变tensor的形状,但共享数据内存,不要直接使用id进行判断

* y = x[:] # 索引

* torch.from_numpy() # Numpy array—>Tensor

* torch.detach() # 新的tensor会脱离计算图,不会牵扯梯度计算。

* model:forward()

还有很多选择函数也是数据共享内存,如index_select() masked_select() gather()。

以及后文提到的就地操作in-place。

深拷贝方法:

* torch.clone() # 新的tensor会保留在计算图中,参与梯度计算

下面进行验证,首先验证浅拷贝:

import torch as t
import numpy as np
a = np.ones(4)
b = t.from_numpy(a) # Numpy->Tensor
print(a)
print(b)
'''输出:
[1. 1. 1. 1.]
tensor([1., 1., 1., 1.], dtype=torch.float64)
'''
b.add_(1)# add_会修改b自身
print(a)
print(b)
'''输出:
[2. 2. 2. 2.]
tensor([2., 2., 2., 2.], dtype=torch.float64)
b进行add操作后, a,b同步发生了变化
'''

Tensor和numpy对象共享内存(浅拷贝操作),所以他们之间的转换很快,且会同步变化。

造torch中y = x + y这样的运算是会新开内存的,然后将y指向新内存。为了进行验证,我们可以使用Python自带的id函数:如果两个实例的ID一致,那么它们所对应的内存地址相同;但需要注意是在torch中还有些特殊,数据共享时直接打印tensor的id仍然会出现不同。

x = torch.tensor([1, 2])
y = torch.tensor([3, 4])
id_0 = id(y)
y = y + x
print(id(y) == id_0) 
# False

这时使用索引操作不会开辟新的内存,而想指定结果到原来的y的内存,我们可以使用索引来进行替换操作。比如把x + y的结果通过[:]写进y对应的内存中。

x = torch.tensor([1, 2])
y = torch.tensor([3, 4])
id_0 = id(y)
y[:] = y + x
print(id(y) == id_0) 
# True

另外,以下两种方式也可以索引到相同的内存:

  • torch.add(x, y, out=y)
  • y += x, y.add_(x)
x = torch.tensor([1, 2])
y = torch.tensor([3, 4])
id_0 = id(y)
torch.add(x, y, out=y) 
# y += x, y.add_(x)
print(id(y) == id_0) 
# True

clone() 与 detach() 对比

Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab。如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝,

首先我们来打印出来clone()操作后的数据类型定义变化:

(1). 简单打印类型

import torch

a = torch.tensor(1.0, requires_grad=True)
b = a.clone()
c = a.detach()
a.data *= 3
b += 1

print(a) # tensor(3., requires_grad=True)
print(b)
print(c)

'''
输出结果:
tensor(3., requires_grad=True)
tensor(2., grad_fn=<AddBackward0>)
tensor(3.)  # detach()后的值随着a的变化出现变化
'''

grad_fn=<CloneBackward>,表示clone后的返回值是个中间变量,因此支持梯度的回溯。clone操作在一定程度上可以视为是一个identity-mapping函数。

detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变。

注意: 在pytorch中我们不要直接使用id是否相等来判断tensor是否共享内存,这只是充分条件,因为也许底层共享数据内存,但是仍然是新的tensor,比如detach(),如果我们直接打印id会出现以下情况。

import torch as t
a = t.tensor([1.0,2.0], requires_grad=True)
b = a.detach()
#c[:] = a.detach()
print(id(a))
print(id(b))
#140568935450520
140570337203616

显然直接打印出来的id不等,我们可以通过简单的赋值后观察数据变化进行判断。

(2). clone()的梯度回传

detach()函数可以返回一个完全相同的tensor,与旧的tensor共享内存,脱离计算图,不会牵扯梯度计算。

而clone充当中间变量,会将梯度传给源张量进行叠加,但是本身不保存其grad,即值为None

import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
y = a**2
z = a ** 2+a_ * 3
y.backward()
print(a.grad) # 2
z.backward()
print(a_.grad) # None. 中间variable,无grad
print(a.grad) 
'''
输出:
tensor(2.) 
None
tensor(7.) # 2*2+3=7
'''

使用torch.clone()获得的新tensor和原来的数据不再共享内存,但仍保留在计算图中,clone操作在不共享数据内存的同时支持梯度梯度传递与叠加,所以常用在神经网络中某个单元需要重复使用的场景下。

通常如果原tensor的requires_grad=True,则:

  • clone()操作后的tensor requires_grad=True
  • detach()操作后的tensor requires_grad=False。
import torch
torch.manual_seed(0)

x= torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone() 
detach_x = x.detach()
clone_detach_x = x.clone().detach() 

f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
输出结果如下:
tensor([-0.0053, 0.3793])
True
None
False
False
'''

另一个比较特殊的是当源张量的 require_grad=False,clone后的张量 require_grad=True,此时不存在张量回传现象,可以得到clone后的张量求导。

如下:

import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_() #require_grad=True
y = a_ ** 2
y.backward()
print(a.grad) # None
print(a_.grad) 
'''
输出:
None
tensor(2.)
'''

总结:

torch.detach() —新的tensor会脱离计算图,不会牵扯梯度计算

torch.clone() — 新的tensor充当中间变量,会保留在计算图中,参与梯度计算(回传叠加),但是一般不会保留自身梯度。

原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面两者中执行都会引发错误或者警告。

引用官方文档的话:如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。另外尽量避免in-place的使用。

到此这篇关于PyTorch中拷贝与就地操作的文章就介绍到这了,更多相关PyTorch拷贝与就地操作内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python使用bs4获取58同城城市分类的方法
Jul 08 Python
在Python的Django框架中更新数据库数据的方法
Jul 17 Python
一个基于flask的web应用诞生 bootstrap框架美化(3)
Apr 11 Python
python3设计模式之简单工厂模式
Oct 17 Python
Python3使用PyQt5制作简单的画板/手写板实例
Oct 19 Python
Python语言描述连续子数组的最大和
Jan 04 Python
python 利用文件锁单例执行脚本的方法
Feb 19 Python
python selenium firefox使用详解
Feb 26 Python
对django views中 request, response的常用操作详解
Jul 17 Python
利用rest framework搭建Django API过程解析
Aug 31 Python
用openCV和Python 实现图片对比,并标识出不同点的方式
Dec 19 Python
PyCharm2019 安装和配置教程详解附激活码
Jul 31 Python
python 调用Google翻译接口的方法
Dec 09 #Python
浅析Python 中的 WSGI 接口和 WSGI 服务的运行
Dec 09 #Python
python dir函数快速掌握用法技巧
Dec 09 #Python
5 分钟读懂Python 中的 Hook 钩子函数
Dec 09 #Python
Python爬虫教程之利用正则表达式匹配网页内容
Dec 08 #Python
Python创建文件夹与文件的快捷方法
Dec 08 #Python
Python之字符串的遍历的4种方式
Dec 08 #Python
You might like
基于qmail的完整WEBMAIL解决方案安装详解
2006/10/09 PHP
php获取目标函数执行时间示例
2014/03/04 PHP
thinkPHP简单导入和使用阿里云OSSsdk的方法
2017/03/15 PHP
yii框架使用分页的方法分析
2019/07/25 PHP
用js实现的模拟jquery的animate自定义动画(2.5K)
2010/07/20 Javascript
JavaScript定义类或函数的几种方式小结
2011/01/09 Javascript
js中的json对象详细介绍
2014/10/29 Javascript
js读取csv文件并使用json显示出来
2015/01/09 Javascript
JavaScript插件化开发教程(六)
2015/02/01 Javascript
Javascript毫秒数用法实例
2015/02/05 Javascript
JavaScript判断浏览器类型的方法
2015/02/10 Javascript
JS实现控制表格内指定单元格内容对齐的方法
2015/03/30 Javascript
解析JavaScript的ES6版本中的解构赋值
2015/07/28 Javascript
如何用js 实现依赖注入的思想,后端框架思想搬到前端来
2015/08/03 Javascript
Js获取图片原始宽高的实现代码
2016/05/17 Javascript
浅谈javascript:两种注释,声明变量,定义函数
2016/10/05 Javascript
AngularJS的ng-repeat指令与scope继承关系实例详解
2017/01/21 Javascript
js 中rewrap-ajax.js插件实例代码
2017/10/20 Javascript
JS匿名函数和匿名自执行函数概念与用法分析
2018/03/16 Javascript
Node 升级到最新稳定版的方法分享
2018/05/17 Javascript
angular将html代码输出为内容的实例
2018/09/30 Javascript
解决微信小程序云开发中获取数据库的内容为空的方法
2019/05/15 Javascript
[33:17]OG vs VGJ.T 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
[48:28]完美世界DOTA2联赛循环赛FTD vs Magma第二场 10月30日
2020/10/31 DOTA
Python中文字符串截取问题
2015/06/15 Python
详细分析Python collections工具库
2020/07/16 Python
香蕉共和国工厂店:Banana Republic Factory
2018/06/09 全球购物
Lookfantastic香港官网:英国知名美妆购物网站
2018/06/19 全球购物
静态成员和非静态成员的区别
2012/05/12 面试题
精彩的推荐信范文
2013/11/26 职场文书
厂长岗位职责
2014/02/19 职场文书
数控技校生自我鉴定
2014/04/19 职场文书
2014年人事科工作总结
2014/11/19 职场文书
2015年党员自我剖析材料
2014/12/17 职场文书
七年级作文之环保作文
2019/10/17 职场文书
GoLang中生成UUID唯一标识的实现
2021/05/08 Golang