解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫包 BeautifulSoup  递归抓取实例详解
Jan 28 Python
Python编程实现两个文件夹里文件的对比功能示例【包含内容的对比】
Jun 20 Python
Python PyQt5实现的简易计算器功能示例
Aug 23 Python
Python实现桶排序与快速排序算法结合应用示例
Nov 22 Python
Python3实现发送QQ邮件功能(html)
Dec 15 Python
Python cookbook(数据结构与算法)让字典保持有序的方法
Feb 18 Python
Python读写zip压缩文件的方法
Aug 29 Python
Django之无名分组和有名分组的实现
Apr 16 Python
python将字符串list写入excel和txt的实例
Jul 20 Python
关于Python Tkinter Button控件command传参问题的解决方式
Mar 04 Python
使用python实现飞机大战游戏
Mar 23 Python
Jupyter 无法下载文件夹如何实现曲线救国
Apr 22 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
php 上传文件类型判断函数(避免上传漏洞 )
2010/06/08 PHP
php设计模式 Strategy(策略模式)
2011/06/26 PHP
php5.4以下版本json不支持不转义内容中文的解决方法
2015/01/13 PHP
PHP利用Socket获取网站的SSL证书与公钥
2017/06/18 PHP
laravel5实现微信第三方登录功能
2018/12/06 PHP
用JS实现一个页面多个css样式实现
2008/05/29 Javascript
刷新页面实现方式总结(HTML,ASP,JS)
2008/11/13 Javascript
html中的input标签的checked属性jquery判断代码
2012/09/19 Javascript
jQuery修改li下的样式以及li下的img的src的值的方法
2014/11/02 Javascript
Yii2使用Bootbox插件实现自定义弹窗
2015/04/02 Javascript
JavaScript中字符串(string)转json的2种方法
2015/06/25 Javascript
javascript日期处理函数,性能优化批处理
2015/09/06 Javascript
深入理解ECMAScript的几个关键语句
2016/06/01 Javascript
十大热门的JavaScript框架和库
2017/03/21 Javascript
详谈angularjs中路由页面强制更新的问题
2017/04/24 Javascript
微信小程序-滚动消息通知的实例代码
2017/08/03 Javascript
微信小程序实现点击按钮移动view标签的位置功能示例【附demo源码下载】
2017/12/06 Javascript
vue.js中$set与数组更新方法
2018/03/08 Javascript
JavaScript canvas绘制折线图
2020/02/18 Javascript
javascript实现下拉菜单效果
2021/02/09 Javascript
[03:49]DOTA2英雄基础教程 光之守卫
2014/01/14 DOTA
[01:32]完美世界DOTA2联赛10月29日精彩集锦
2020/10/30 DOTA
python实现音乐下载的统计
2018/06/20 Python
数据清洗--DataFrame中的空值处理方法
2018/07/03 Python
浅谈pandas用groupby后对层级索引levels的处理方法
2018/11/06 Python
End Clothing美国站:英国男士潮牌商城
2018/04/20 全球购物
中英文自我评价常用句型
2013/12/19 职场文书
汉语言文学职业规划
2014/02/14 职场文书
《再见了,亲人》教学反思
2014/02/26 职场文书
投标承诺书范本
2014/03/27 职场文书
党员自我评议对照检查材料
2014/09/27 职场文书
2015年班主任个人工作总结
2015/03/31 职场文书
总经理致辞
2015/07/29 职场文书
2016年师德学习心得体会
2016/01/12 职场文书
HR必备:销售经理聘用合同范本
2019/08/21 职场文书
SQL Server查询某个字段在哪些表中存在
2022/03/03 SQL Server