解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
精确查找PHP WEBSHELL木马的方法(1)
Apr 12 Python
python实现将英文单词表示的数字转换成阿拉伯数字的方法
Jul 02 Python
python监控文件或目录变化
Jun 07 Python
Python 通过URL打开图片实例详解
Jun 01 Python
Python实现Linux的find命令实例分享
Jun 04 Python
安装Python的教程-Windows
Jul 22 Python
基于python(urlparse)模板的使用方法总结
Oct 13 Python
python docx 中文字体设置的操作方法
May 08 Python
Python Tkinter 简单登录界面的实现
Jun 14 Python
浅谈python3 构造函数和析构函数
Mar 12 Python
TensorFlow保存TensorBoard图像操作
Jun 23 Python
Python高并发解决方案实现过程详解
Jul 31 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
SONY ICF-SW55的电路分析
2021/03/02 无线电
PHP之uniqid()函数用法
2014/11/03 PHP
PHP生成树的方法
2015/07/28 PHP
PHP版本升级到7.x后wordpress的一些修改及wordpress技巧
2015/12/25 PHP
Laravel框架FormRequest中重写错误处理的方法
2019/02/18 PHP
判断页面是关闭还是刷新的js代码
2007/01/28 Javascript
IE与Firefox下javascript getyear年份的兼容性写法
2007/12/20 Javascript
ext监听事件方法[初级篇]
2008/04/27 Javascript
js 无提示关闭浏览器页面的代码
2010/03/09 Javascript
jquery对单选框,多选框,文本框等常见操作小结
2014/01/08 Javascript
JS动态创建DOM元素的方法
2015/06/09 Javascript
全面解析Bootstrap表单样式的使用
2016/09/09 Javascript
easyUI combobox实现联动效果
2017/01/17 Javascript
javascript 判断一个对象为数组的方法
2017/05/03 Javascript
JavaScript中AOP的实现与应用
2019/05/06 Javascript
three.js利用gpu选取物体并计算交点位置的方法示例
2019/11/25 Javascript
Vue(定时器)解决mounted不能获取到data中的数据问题
2020/07/30 Javascript
[56:13]DOTA2-DPC中国联赛定级赛 LBZS vs Phoenix BO3第一场 1月10日
2021/03/11 DOTA
Django imgareaselect手动剪切头像实现方法
2015/05/26 Python
django ajax json的实例代码
2018/05/29 Python
基于python实现聊天室程序
2018/07/27 Python
python爬取内容存入Excel实例
2019/02/20 Python
图文详解Django使用Pycharm连接MySQL数据库
2019/08/09 Python
关于Python 常用获取元素 Driver 总结
2019/11/24 Python
Tensorflow进行多维矩阵的拆分与拼接实例
2020/02/07 Python
python tkinter 设置窗口大小不可缩放实例
2020/03/04 Python
python模拟实现分发扑克牌
2020/04/22 Python
Anya Hindmarch官网:奢侈设计师手袋及配饰
2018/11/15 全球购物
希腊品牌鞋类销售网站:epapoutsia.gr
2020/03/18 全球购物
大专生自我鉴定范文
2013/10/01 职场文书
优秀求职信范文分享
2013/12/19 职场文书
模具专业求职信
2014/06/26 职场文书
群教班子对照检查材料
2014/08/26 职场文书
中秋节国旗下演讲稿
2014/09/05 职场文书
十一月早安语录:把心放轻,人生就是一朵自在的云
2019/11/04 职场文书
Python爬虫入门案例之爬取二手房源数据
2021/10/16 Python