解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从零学Python之入门(三)序列
May 25 Python
在Django中创建动态视图的教程
Jul 15 Python
Python的iOS自动化打包实例代码
Nov 22 Python
python 用下标截取字符串的实例
Dec 25 Python
Django uwsgi Nginx 的生产环境部署详解
Feb 02 Python
python 利用文件锁单例执行脚本的方法
Feb 19 Python
pyqt5利用pyqtDesigner实现登录界面
Mar 28 Python
linux中如何使用python3获取ip地址
Jul 15 Python
Tensorflow tf.dynamic_partition矩阵拆分示例(Python3)
Feb 07 Python
Python 开发工具通过 agent 代理使用的方法
Sep 27 Python
python爬虫---requests库的用法详解
Sep 28 Python
python pandas 解析(读取、写入)CSV 文件的操作方法
Dec 24 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
phpMyAdmin下载、安装和使用入门教程
2007/05/31 PHP
smarty模板嵌套之include与fetch性能测试
2010/12/05 PHP
php中session_unset与session_destroy的区别分析
2011/06/16 PHP
CI框架在CLI下执行占用内存过大问题的解决方法
2014/06/17 PHP
几道坑人的PHP面试题 试试看看你会不会也中招
2014/08/19 PHP
PHP提示Warning:phpinfo() has been disabled函数禁用的解决方法
2014/12/17 PHP
php生出随机字符串
2017/07/06 PHP
PHP中常见的密码处理方式和建议总结
2018/10/14 PHP
javascript 延迟加载技术(lazyload)简单实现
2011/01/17 Javascript
js获取触发事件元素在整个网页中的绝对坐标(示例代码)
2013/12/13 Javascript
iframe调用父页面函数示例详解
2014/07/17 Javascript
使用JS实现jQuery的addClass, removeClass, hasClass函数功能
2014/10/31 Javascript
javascript实现对表格元素进行排序操作
2015/11/18 Javascript
微信js-sdk预览图片接口及从拍照或手机相册中选图接口用法示例
2016/10/13 Javascript
原生javascript实现分页效果
2017/04/21 Javascript
以BootStrap Tab为例写一个前端组件
2017/07/25 Javascript
[04:53]DOTA2英雄基础教程 祈求者
2014/01/03 DOTA
[01:30:15]DOTA2-DPC中国联赛 正赛 Ehome vs Aster BO3 第二场 2月2日
2021/03/11 DOTA
Python中的字符串操作和编码Unicode详解
2017/01/18 Python
简单实现python收发邮件功能
2018/01/05 Python
python+opencv 读取文件夹下的所有图像并批量保存ROI的方法
2019/01/10 Python
python flask解析json数据不完整的解决方法
2019/05/26 Python
Python OpenCV实现视频分帧
2019/06/01 Python
Django Aggregation聚合使用方法解析
2019/08/01 Python
Python 防止死锁的方法
2020/07/29 Python
CSS3轻松实现圆角效果
2017/11/09 HTML / CSS
css3绘制天猫logo实现代码
2012/11/06 HTML / CSS
美国婚礼和派对礼品网站:Kate Aspen(新娘送礼会、迎婴派对)
2018/03/28 全球购物
6PM官网:折扣鞋、服装及配饰
2018/08/03 全球购物
专业销售业务员求职信
2013/11/18 职场文书
购房协议书范本
2014/04/11 职场文书
中药学专业求职信
2014/05/31 职场文书
公司行政专员岗位职责
2014/08/24 职场文书
2015年销售内勤工作总结
2015/04/27 职场文书
学校捐书倡议书
2015/04/27 职场文书
SQL Server删除表中的重复数据
2022/05/25 SQL Server