Pytorch 中retain_graph的用法详解


Posted in Python onJanuary 07, 2020

用法分析

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?

############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,

如下代码:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

输出如下错误信息:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正确:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

Variable 类源代码

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: 任意类型的封装好的张量。
    grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
    requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
    volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
    is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包装的张量.
    requires_grad (bool): bool型的标记值. **Keyword only.**
    volatile (bool): bool型的标记值. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """计算关于当前图叶子变量的梯度,图使用链式法则导致分化
    如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
    如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
    需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
    可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
    函数在叶子上累积梯度,调用前需要对该叶子进行清零。
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
              可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
      retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。
                      在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
                      默认值为create_graph的值。
      create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。
                      默认为False,除非``gradient``是一个volatile变量。
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None
    不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
 
    函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
    Parameters:
      reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """返回一个从当前图分离出来的心变量。
    结果不需要梯度,如果输入是volatile,则输出也是volatile。
 
    .. 注意::
     返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """从创建它的图中分离出变量并作为该图的一个叶子"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True

以上这篇Pytorch 中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python version 2.7 required, which was not found in the registry
Aug 26 Python
Python 实现链表实例代码
Apr 07 Python
python中如何使用正则表达式的非贪婪模式示例
Oct 09 Python
深入理解Python分布式爬虫原理
Nov 23 Python
利用python编写一个图片主色转换的脚本
Dec 07 Python
django小技巧之html模板中调用对象属性或对象的方法
Nov 30 Python
python实现简单日期工具类
Apr 24 Python
零基础使用Python读写处理Excel表格的方法
May 02 Python
django的聚合函数和aggregate、annotate方法使用详解
Jul 23 Python
Python中的全局变量如何理解
Jun 04 Python
python解压zip包中文乱码解决方法
Nov 27 Python
python第三方网页解析器 lxml 扩展库与 xpath 的使用方法
Apr 06 Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
You might like
ubuntu 编译安装php 5.3.3+memcache的方法
2010/08/05 PHP
php iconv() : Detected an illegal character in input string
2010/12/05 PHP
ThinkPHP进程计数类Process用法实例详解
2015/09/25 PHP
网页里控制图片大小的相关代码
2006/06/25 Javascript
js调用flash的效果代码
2008/04/26 Javascript
js 蒙版进度条(结合图片)
2010/03/10 Javascript
javascript实现按回车键切换焦点
2015/02/09 Javascript
Angular ng-class详解及实例代码
2016/09/19 Javascript
微信小程序 欢迎界面开发的实例详解
2016/11/30 Javascript
详解jQuery lazyload 懒加载
2016/12/19 Javascript
vue组件Prop传递数据的实现示例
2017/08/17 Javascript
webpack处理 css\less\sass 样式的方法
2017/08/21 Javascript
JQuery 又谈ajax局部刷新
2017/11/27 jQuery
浅谈webpack打包之后的文件过大的解决方法
2018/03/07 Javascript
Vue 实现对quill-editor组件中的工具栏添加title
2020/08/03 Javascript
Python如何通过subprocess调用adb命令详解
2017/08/27 Python
Python+Turtle动态绘制一棵树实例分享
2018/01/16 Python
python 获取指定文件夹下所有文件名称并写入列表的实例
2018/04/23 Python
如何优雅地改进Django中的模板碎片缓存详解
2018/07/04 Python
Python2.7实现多进程下开发多线程示例
2019/05/31 Python
python实现连连看辅助(图像识别)
2020/03/25 Python
Django ORM filter() 的运用详解
2020/05/14 Python
Python中有几个关键字
2020/06/04 Python
新加坡交友网站:be2新加坡
2019/04/10 全球购物
写出SQL四条最基本的数据操作语句(DML)
2012/12/12 面试题
专科毕业生自我鉴定
2013/12/01 职场文书
不假外出检讨书
2014/01/27 职场文书
个人简历中自我评价
2014/02/11 职场文书
家具商场的活动方案
2014/08/16 职场文书
助人为乐道德模范事迹材料
2014/08/16 职场文书
学习十八大的心得体会
2014/09/01 职场文书
党员学习中共十八大思想报告
2014/09/12 职场文书
公积金贷款承诺书
2015/04/30 职场文书
Java面试题冲刺第十七天--基础篇3
2021/08/07 面试题
Python的三个重要函数详解
2022/01/18 Python
Oracle中DBLink的详细介绍
2022/04/29 Oracle