Pytorch 中retain_graph的用法详解


Posted in Python onJanuary 07, 2020

用法分析

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?

############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,

如下代码:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

输出如下错误信息:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正确:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

Variable 类源代码

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: 任意类型的封装好的张量。
    grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
    requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
    volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
    is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包装的张量.
    requires_grad (bool): bool型的标记值. **Keyword only.**
    volatile (bool): bool型的标记值. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """计算关于当前图叶子变量的梯度,图使用链式法则导致分化
    如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
    如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
    需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
    可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
    函数在叶子上累积梯度,调用前需要对该叶子进行清零。
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
              可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
      retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。
                      在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
                      默认值为create_graph的值。
      create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。
                      默认为False,除非``gradient``是一个volatile变量。
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None
    不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
 
    函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
    Parameters:
      reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """返回一个从当前图分离出来的心变量。
    结果不需要梯度,如果输入是volatile,则输出也是volatile。
 
    .. 注意::
     返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """从创建它的图中分离出变量并作为该图的一个叶子"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True

以上这篇Pytorch 中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用logging模块代替print(logging简明指南)
Jul 09 Python
举例详解Python中的split()函数的使用方法
Apr 07 Python
Python中的字符串类型基本知识学习教程
Feb 04 Python
python机器学习之神经网络(二)
Dec 20 Python
django连接mysql配置方法总结(推荐)
Aug 18 Python
Python socket实现多对多全双工通信的方法
Feb 13 Python
解决python Markdown模块乱码的问题
Feb 14 Python
Python使用贪婪算法解决问题
Oct 22 Python
基于python实现把图片转换成素描
Nov 13 Python
Python post请求实现代码实例
Feb 28 Python
Python利用Faiss库实现ANN近邻搜索的方法详解
Aug 03 Python
Python捕获、播放和保存摄像头视频并提高视频清晰度和对比度
Apr 14 Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
You might like
最令PHP初学者们头痛的十四个问题
2007/01/15 PHP
最新的php 文件上传模型,支持多文件上传
2009/08/13 PHP
PDO::getAvailableDrivers讲解
2019/01/28 PHP
利用PHP内置SERVER开启web服务(本地开发使用)
2020/01/22 PHP
新增加的内容是如何将div的scrollbar自动移动最下面
2014/01/02 Javascript
JavaScript中奇葩的假值示例应用
2014/03/11 Javascript
jQuery中:selected选择器用法实例
2015/01/04 Javascript
js实现漂浮回顶部按钮实例
2015/05/06 Javascript
springmvc接收jquery提交的数组数据代码分享
2017/10/28 jQuery
微信小程序模拟cookie的实现
2018/06/20 Javascript
vue+canvas实现炫酷时钟效果的倒计时插件(已发布到npm的vue2插件,开箱即用)
2018/11/05 Javascript
微信小程序的注册页面包含倒计时验证码、获取用户信息
2019/05/22 Javascript
[40:55]DOTA2上海特级锦标赛主赛事日 - 2 败者组第二轮#4Newbee VS Fnatic
2016/03/03 DOTA
python中精确输出JSON浮点数的方法
2014/04/18 Python
python中利用xml.dom模块解析xml的方法教程
2017/05/24 Python
Python爬虫常用库的安装及其环境配置
2018/09/19 Python
Python2与Python3的区别实例总结
2019/04/17 Python
梅尔倒谱系数(MFCC)实现
2019/06/19 Python
Python3 xml.etree.ElementTree支持的XPath语法详解
2020/03/06 Python
python多进程 主进程和子进程间共享和不共享全局变量实例
2020/04/25 Python
python 装饰器的实际作用有哪些
2020/09/07 Python
Python基于unittest实现测试用例执行
2020/11/25 Python
详解如何在css中引入自定义字体(font-face)
2018/05/17 HTML / CSS
美国最大的船只买卖在线市场:Boat Trader
2018/08/04 全球购物
英国著名的美容护肤和护发产品购物网站:Lookfantastic
2020/11/23 全球购物
2014年开学第一课活动方案
2014/03/06 职场文书
任命书格式
2014/06/05 职场文书
学习雷锋月活动总结
2014/07/03 职场文书
文秘应届生求职信
2014/07/05 职场文书
司机岗位职责说明书
2014/07/29 职场文书
2014年银行信贷员工作总结
2014/12/08 职场文书
小兵张嘎电影观后感
2015/06/03 职场文书
2016年学校禁毒宣传活动工作总结
2016/04/05 职场文书
Oracle11g r2 卸载干净重装的详细教程(亲测有效已重装过)
2021/06/04 Oracle
Java Lambda表达式常用的函数式接口
2022/04/07 Java/Android
Win11使用CAD卡顿或者致命错误怎么办?Win11无法正常使用CAD的解决方法
2022/07/23 数码科技