解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python探索之自定义实现线程池
Oct 27 Python
python 遍历目录(包括子目录)下所有文件的实例
Jul 11 Python
Python get获取页面cookie代码实例
Sep 12 Python
浅谈python在提示符下使用open打开文件失败的原因及解决方法
Nov 30 Python
基于python生成器封装的协程类
Mar 20 Python
Django框架视图层URL映射与反向解析实例分析
Jul 29 Python
Python实现TCP探测目标服务路由轨迹的原理与方法详解
Sep 04 Python
python机器学习库xgboost的使用
Jan 20 Python
Python利用Xpath选择器爬取京东网商品信息
Jun 01 Python
浅谈python 调用open()打开文件时路径出错的原因
Jun 05 Python
Python如何定义接口和抽象类
Jul 28 Python
Python读写压缩文件的方法
Jul 30 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
PHP生成唯一ID之SnowFlake算法
2016/12/17 PHP
新闻内页-JS分页
2006/06/07 Javascript
HTML中的setCapture和releaseCapture使用介绍
2012/03/21 Javascript
JavaScript加强之自定义callback示例
2013/09/21 Javascript
jquery Ajax 实现加载数据前动画效果的示例代码
2014/02/07 Javascript
关于JavaScript作用域你想知道的一切
2016/02/04 Javascript
ionic实现滑动的三种方式
2016/08/27 Javascript
一个可复用的vue分页组件
2017/05/15 Javascript
JQuery 获取多个select标签option的text内容(实例)
2017/09/07 jQuery
js 公式编辑器 - 自定义匹配规则 - 带提示下拉框 - 动态获取光标像素坐标
2018/01/04 Javascript
浅谈vue项目重构技术要点和总结
2018/01/23 Javascript
详解React项目的服务端渲染改造(koa2+webpack3.11)
2018/03/19 Javascript
Vue源码解析之数组变异的实现
2018/12/04 Javascript
使用TS来编写express服务器的方法步骤
2020/10/29 Javascript
JavaScript this关键字的深入详解
2021/01/14 Javascript
[03:11]DOTA2上海特锦赛小组赛第一日recap精彩回顾
2016/02/28 DOTA
Python版微信红包分配算法
2015/05/04 Python
python绘图方法实例入门
2015/05/19 Python
基于wxpython实现的windows GUI程序实例
2015/05/30 Python
通过Python使用saltstack生成服务器资产清单
2016/03/01 Python
对pyqt5之menu和action的使用详解
2019/06/20 Python
Python实现word2Vec model过程解析
2019/12/16 Python
对Python中 \r, \n, \r\n的彻底理解
2020/03/06 Python
python爬虫学习笔记之Beautifulsoup模块用法详解
2020/04/09 Python
python自动化办公操作PPT的实现
2021/02/05 Python
html5构建触屏网站之网站尺寸探讨
2013/01/07 HTML / CSS
EJB需直接实现它的业务接口或Home接口吗,请简述理由
2016/11/23 面试题
工程质量承诺书
2014/03/27 职场文书
环保志愿者活动总结
2014/06/27 职场文书
交通事故赔偿协议书怎么写
2014/10/04 职场文书
2014年审计人员工作总结
2014/12/19 职场文书
简短清晨问候语
2015/11/10 职场文书
会计手工模拟做账心得体会
2016/01/22 职场文书
Python爬取科目四考试题库的方法实现
2021/03/30 Python
手把手教你怎么用Python实现zip文件密码的破解
2021/05/27 Python
MySQL中IO问题的深入分析与优化
2022/04/02 MySQL