解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python调用C语言开发的共享库方法实例
Mar 18 Python
Python中每次处理一个字符的5种方法
May 21 Python
Python中使用Queue和Condition进行线程同步的方法
Jan 19 Python
python实现判断一个字符串是否是合法IP地址的示例
Jun 04 Python
pandas ix &iloc &loc的区别
Jan 10 Python
对Python Pexpect 模块的使用说明详解
Feb 14 Python
详解Python中的内建函数,可迭代对象,迭代器
Apr 29 Python
pygame实现俄罗斯方块游戏(基础篇2)
Oct 29 Python
python新式类和经典类的区别实例分析
Mar 23 Python
Python使用pyexecjs代码案例解析
Jul 13 Python
python实现模拟器爬取抖音评论数据的示例代码
Jan 06 Python
python opencv通过按键采集图片源码
May 20 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
PHP中的命名空间相关概念浅析
2015/01/22 PHP
PHP 文件锁与进程锁的使用示例
2017/08/07 PHP
一段好玩的JavaScript代码
2006/12/01 Javascript
利用js获取服务器时间的两个简单方法
2010/01/08 Javascript
javascript页面倒计时实例
2015/07/25 Javascript
Javascript设计模式理论与编程实战之简单工厂模式
2015/11/03 Javascript
js实现的星星评分功能函数
2015/12/09 Javascript
JavaScript setTimeout使用闭包功能实现定时打印数值
2015/12/18 Javascript
详解JavaScript正则表达式之分组匹配及反向引用
2016/03/09 Javascript
vue中各种通信传值方式总结
2019/02/14 Javascript
vue项目从node8.x升级到12.x后的问题解决
2019/10/25 Javascript
vue实现短信验证码输入框
2020/04/17 Javascript
ES6函数和数组用法实例分析
2020/05/23 Javascript
Python进程间通信用法实例
2015/06/04 Python
python利用不到一百行代码实现一个小siri
2017/03/02 Python
Python实现树莓派WiFi断线自动重连的实例代码
2017/03/16 Python
opencv python 基于KNN的手写体识别的实例
2018/08/03 Python
python使用epoll实现服务端的方法
2018/10/16 Python
django+tornado实现实时查看远程日志的方法
2019/08/12 Python
python检测服务器端口代码实例
2019/08/31 Python
win10系统下python3安装及pip换源和使用教程
2020/01/06 Python
python全局变量引用与修改过程解析
2020/01/07 Python
Python 实现进度条的六种方式
2021/01/06 Python
美国牙科折扣计划:DentalPlans.com
2019/08/26 全球购物
自荐信要包含哪些内容
2013/11/06 职场文书
宗教学大学生职业生涯规划范文
2014/02/08 职场文书
经销商订货会主持词
2014/03/27 职场文书
优秀求职信
2014/05/29 职场文书
考试作弊万能检讨书
2014/10/19 职场文书
关爱留守儿童主题班会
2015/08/13 职场文书
银行客户经理培训心得体会
2016/01/09 职场文书
2017年寒假社区服务活动总结
2016/04/06 职场文书
Nginx解决403 forbidden的完整步骤
2021/04/01 Servers
新手必备Python开发环境搭建教程
2021/05/28 Python
springboot临时文件存储目录配置方式
2021/07/01 Java/Android
MySQL性能指标TPS+QPS+IOPS压测
2022/08/05 MySQL