Pytorch 中retain_graph的用法详解


Posted in Python onJanuary 07, 2020

用法分析

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?

############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,

如下代码:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

输出如下错误信息:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正确:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

Variable 类源代码

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: 任意类型的封装好的张量。
    grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
    requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
    volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
    is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包装的张量.
    requires_grad (bool): bool型的标记值. **Keyword only.**
    volatile (bool): bool型的标记值. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """计算关于当前图叶子变量的梯度,图使用链式法则导致分化
    如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
    如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
    需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
    可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
    函数在叶子上累积梯度,调用前需要对该叶子进行清零。
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
              可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
      retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。
                      在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
                      默认值为create_graph的值。
      create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。
                      默认为False,除非``gradient``是一个volatile变量。
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None
    不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
 
    函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
    Parameters:
      reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """返回一个从当前图分离出来的心变量。
    结果不需要梯度,如果输入是volatile,则输出也是volatile。
 
    .. 注意::
     返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """从创建它的图中分离出变量并作为该图的一个叶子"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True

以上这篇Pytorch 中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python中的greenlet包实现并发编程的入门教程
Apr 16 Python
详解python并发获取snmp信息及性能测试
Mar 27 Python
Python编程之序列操作实例详解
Jul 22 Python
Django与JS交互的示例代码
Aug 23 Python
python实现的MySQL增删改查操作实例小结
Dec 19 Python
Django uwsgi Nginx 的生产环境部署详解
Feb 02 Python
Python 微信爬虫完整实例【单线程与多线程】
Jul 06 Python
Python中正反斜杠(‘/’和‘\’)的意义与用法
Aug 12 Python
Python文件操作基础流程解析
Mar 19 Python
使用matplotlib动态刷新指定曲线实例
Apr 23 Python
python多进程 主进程和子进程间共享和不共享全局变量实例
Apr 25 Python
ITK 实现多张图像转成单个nii.gz或mha文件案例
Jul 01 Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
You might like
星际争霸任务指南——神族
2020/03/04 星际争霸
php过滤html中的其他网站链接的方法(域名白名单功能)
2014/04/24 PHP
php中HTTP_REFERER函数用法实例
2014/11/21 PHP
深入php内核之php in array
2015/11/10 PHP
PHP调用Mailgun发送邮件的方法
2017/05/04 PHP
php利用云片网实现短信验证码功能的示例代码
2017/11/18 PHP
jquery常用特效方法使用示例
2014/04/25 Javascript
微信小程序 弹框和模态框实现代码
2017/03/10 Javascript
小程序图片长按识别功能的实现方法
2018/08/30 Javascript
vue项目使用微信公众号支付总结及遇到的坑
2018/10/23 Javascript
JavaScript查看代码运行效率console.time()与console.timeEnd()用法
2019/01/18 Javascript
在React中写一个Animation组件为组件进入和离开加上动画/过度效果
2019/06/24 Javascript
vue实现几秒后跳转新页面代码
2020/09/09 Javascript
二种python发送邮件实例讲解(python发邮件附件可以使用email模块实现)
2013/12/03 Python
Python实现给文件添加内容及得到文件信息的方法
2015/05/28 Python
Python过滤列表用法实例分析
2016/04/29 Python
TF-IDF与余弦相似性的应用(二) 找出相似文章
2017/12/21 Python
Python采集代理ip并判断是否可用和定时更新的方法
2018/05/07 Python
Django使用paginator插件实现翻页功能的实例
2018/10/24 Python
python处理两种分隔符的数据集方法
2018/12/12 Python
python使用插值法画出平滑曲线
2018/12/15 Python
大家都说好用的Python命令行库click的使用
2019/11/07 Python
Flask项目中实现短信验证码和邮箱验证码功能
2019/12/05 Python
解决windows下python3使用multiprocessing.Pool出现的问题
2020/04/08 Python
Python3自动生成MySQL数据字典的markdown文本的实现
2020/05/07 Python
浅谈Python中的模块
2020/06/10 Python
python相对企业语言优势在哪
2020/06/12 Python
详解python 内存优化
2020/08/17 Python
美国环保婴儿用品公司:The Honest Company
2017/11/23 全球购物
后勤人员自我鉴定
2013/10/20 职场文书
新学期开学演讲稿
2014/05/24 职场文书
导航工程专业自荐信
2014/09/02 职场文书
社会实践活动总结格式
2015/05/11 职场文书
大学学生会辞职信
2015/05/13 职场文书
2015年教研员工作总结
2015/05/26 职场文书
告知书格式
2015/07/01 职场文书