解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python模块学习 datetime介绍
Aug 27 Python
10种检测Python程序运行时间、CPU和内存占用的方法
Apr 01 Python
在Python中操作字典之fromkeys()方法的使用
May 21 Python
关于numpy中np.nonzero()函数用法的详解
Feb 07 Python
socket + select 完成伪并发操作的实例
Aug 15 Python
Python绘制的二项分布概率图示例
Aug 22 Python
python搜索包的路径的实现方法
Jul 19 Python
python logging添加filter教程
Dec 24 Python
怎么快速自学python
Jun 22 Python
Python 操作 MySQL数据库
Sep 18 Python
PyQt5 显示超清高分辨率图片的方法
Apr 11 Python
Django数据库(SQlite)基本入门使用教程
Jul 07 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
成为好程序员必须避免的5个坏习惯
2014/07/04 PHP
Fedora下安装php Redis扩展笔记
2014/09/03 PHP
使用Huagepage和PGO来提升PHP7的执行性能
2015/11/30 PHP
解读PHP的Yii框架中请求与响应的处理流程
2016/03/17 PHP
PHP有序表查找之二分查找(折半查找)算法示例
2018/02/09 PHP
详细解读php的命名空间(一)
2018/02/21 PHP
jQuery Tips 为AJAX回调函数传递额外参数的方法
2010/12/28 Javascript
js实现连个数字相加而不是拼接的方法
2014/02/23 Javascript
使用JavaScript 编写简单计算器
2014/11/24 Javascript
JavaScript自定义数组排序方法
2015/02/12 Javascript
学习javascript的闭包,原型,和匿名函数之旅
2015/10/18 Javascript
JS实现的鼠标跟随代码(卡通手型点击效果)
2015/10/26 Javascript
详解javascript数组去重问题
2015/11/06 Javascript
js与jQuery实现checkbox复选框全选/全不选的方法
2016/01/05 Javascript
JavaScript组合模式学习要点
2016/08/26 Javascript
NodeJS遍历文件生产文件列表功能示例
2017/01/22 NodeJs
深入浅析ES6 Class 中的 super 关键字
2017/10/20 Javascript
关闭Vue计算属性自带的缓存功能方法
2018/03/02 Javascript
JS实现的哈夫曼编码示例【原始版与修改版】
2018/04/22 Javascript
mockjs+vue页面直接展示数据的方法
2018/12/19 Javascript
关于layui导航栏不展示下拉列表的解决方法
2019/09/25 Javascript
Python深入学习之闭包
2014/08/31 Python
python中int与str互转方法
2018/07/02 Python
pandas pivot_table() 按日期分多列数据的方法
2018/11/16 Python
对python特殊函数 __call__()的使用详解
2019/07/02 Python
pandas按行按列遍历Dataframe的几种方式
2019/10/23 Python
python读取raw binary图片并提取统计信息的实例
2020/01/09 Python
Python排序函数的使用方法详解
2020/12/11 Python
函授毕业自我鉴定
2013/12/19 职场文书
平民服装店创业计划书
2014/01/17 职场文书
环保建议书300字
2014/05/14 职场文书
医学求职信
2014/05/28 职场文书
群众路线教育实践活动自我剖析思想汇报
2014/10/04 职场文书
清明节主题班会
2015/08/14 职场文书
青少年法制教育心得体会
2016/01/14 职场文书
使用Python通过企业微信应用给企业成员发消息
2022/04/18 Python