解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的jquery PyQuery库使用小结
May 13 Python
python中快速进行多个字符替换的方法小结
Dec 15 Python
requests和lxml实现爬虫的方法
Jun 11 Python
Python数据结构与算法之图结构(Graph)实例分析
Sep 05 Python
全面分析Python的优点和缺点
Feb 07 Python
利用Python如何生成便签图片详解
Jul 09 Python
python 反向输出字符串的方法
Jul 16 Python
python调用java的jar包方法
Dec 15 Python
python安装requests库的实例代码
Jun 25 Python
PyCharm 配置远程python解释器和在本地修改服务器代码
Jul 23 Python
Python多线程多进程实例对比解析
Mar 12 Python
基于tensorflow权重文件的解读
May 26 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
function.inc.php超越php
2006/12/09 PHP
php连接mysql数据库代码
2009/03/10 PHP
php源码加密 仿微盾PHP加密专家(PHPCodeLock)
2010/05/06 PHP
ThinkPHP中处理表单中的注意事项
2014/11/22 PHP
使用纯php代码实现页面伪静态的方法
2015/07/25 PHP
Symfony2学习笔记之控制器用法详解
2016/03/17 PHP
示例详解Laravel重置密码代码重构
2016/08/10 PHP
Aster vs KG BO3 第一场2.19
2021/03/10 DOTA
JQuery SELECT单选模拟jQuery.select.js
2009/11/12 Javascript
JavaScript之HTMLCollection接口代码
2011/04/27 Javascript
基于jquery可配置循环左右滚动例子
2011/09/09 Javascript
深入理解javascript中defer的作用
2013/12/11 Javascript
JavaScript中的方法调用详细介绍
2014/12/30 Javascript
基于JS实现textarea中获取动态剩余字数的方法
2016/05/25 Javascript
浅谈javascript:两种注释,声明变量,定义函数
2016/09/29 Javascript
超全面的vue.js使用总结
2017/02/12 Javascript
详解基于webpack搭建react运行环境
2017/06/01 Javascript
jQuery 改变P标签文本值方法
2018/02/24 jQuery
使用Angular Cli如何创建Angular私有库详解
2019/01/30 Javascript
js实现上下左右键盘控制div移动
2020/01/16 Javascript
js与jquery获取input输入框中的值实例讲解
2020/02/27 jQuery
js 实现碰撞检测的示例
2020/10/28 Javascript
JavaScript仿京东轮播图效果
2021/02/25 Javascript
Python3中多线程编程的队列运作示例
2015/04/16 Python
PyQt5每天必学之滑块控件QSlider
2018/04/20 Python
Python生成一个迭代器的实操方法
2019/06/18 Python
python移位运算的实现
2019/07/15 Python
Django文件存储 默认存储系统解析
2019/08/02 Python
Python银行系统实战源码
2019/10/25 Python
python对文件的操作方法汇总
2020/02/28 Python
日本最大的眼镜购物网站:Oh My Glasses
2016/11/13 全球购物
MCAKE蛋糕官方网站:一直都是巴黎的味道
2018/02/06 全球购物
TripAdvisor斯洛伐克:阅读评论、比较价格和酒店预订
2018/04/25 全球购物
高三毕业生自我鉴定
2013/12/20 职场文书
大学文艺委员竞选稿
2015/11/19 职场文书
速龙x4-860k处理器相当于i几
2022/04/20 数码科技