解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的super用法详解
May 28 Python
python字典基本操作实例分析
Jul 11 Python
python实现汉诺塔方法汇总
Jul 25 Python
python 打印对象的所有属性值的方法
Sep 11 Python
python实现闹钟定时播放音乐功能
Jan 25 Python
Python自动发送邮件的方法实例总结
Dec 08 Python
django自带serializers序列化返回指定字段的方法
Aug 21 Python
关于Python3 lambda函数的深入浅出
Nov 27 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 Python
如何解决cmd运行python提示不是内部命令
Jul 01 Python
python mongo 向数据中的数组类型新增数据操作
Dec 05 Python
Python+Appium新手教程
Apr 17 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
PHP SEO优化之URL优化方法
2011/04/21 PHP
php自定义函数实现JS的escape的方法示例
2016/07/07 PHP
PHP与jquery实时显示网站在线人数实例详解
2016/12/02 PHP
PHP中字符串长度的截取用法示例
2017/01/12 PHP
js 高效去除数组重复元素示例代码
2013/12/19 Javascript
js document.write()使用介绍
2014/02/21 Javascript
js行号显示的文本框实现效果(兼容多种浏览器 )
2015/10/23 Javascript
AngularJS入门教程之表格实例详解
2016/07/27 Javascript
详解jQuery插件开发方式
2016/11/22 Javascript
详解Javascript数据类型的转换规则
2016/12/12 Javascript
Angularjs中使用layDate日期控件示例
2017/01/11 Javascript
bootstrap datetimepicker日期插件超详细使用方法介绍
2017/02/23 Javascript
纯js的右下角弹窗实例
2017/03/12 Javascript
五步轻松实现zTree的使用
2017/11/01 Javascript
详解Vue 全局引入bass.scss 处理方案
2018/03/26 Javascript
vue车牌号校验和银行校验实战
2019/01/23 Javascript
layui在form表单页面通过Validform加入简单验证的方法
2019/09/06 Javascript
ES2020系列之空值合并运算符 '??'
2020/07/22 Javascript
JavaScript实现点击切换验证码及校验
2021/01/10 Javascript
跟老齐学Python之从格式化表达式到方法
2014/09/28 Python
Python3安装Pymongo详细步骤
2017/05/26 Python
Python爬虫实例_城市公交网络站点数据的爬取方法
2018/01/10 Python
从请求到响应过程中django都做了哪些处理
2018/08/01 Python
浅谈Pycharm最有必要改的几个默认设置项
2020/02/14 Python
Python *args和**kwargs用法实例解析
2020/03/02 Python
python绘图pyecharts+pandas的使用详解
2020/12/13 Python
美国知名的家庭连锁百货商店:Boscov’s
2017/07/27 全球购物
美国便宜的横幅和标志印刷在线:Best of Signs
2019/05/29 全球购物
美国最大的在线生存商店:Survival Frog
2020/12/13 全球购物
文字自荐书范文
2014/02/10 职场文书
个人四风问题对照检查材料
2014/10/01 职场文书
2014年大学班级工作总结
2014/11/14 职场文书
2014年结对帮扶工作总结
2014/12/17 职场文书
世界卫生日宣传活动总结
2015/02/09 职场文书
python代码实现备忘录案例讲解
2021/07/26 Python
sql注入报错之注入原理实例解析
2022/06/10 MySQL