Pytorch 中retain_graph的用法详解


Posted in Python onJanuary 07, 2020

用法分析

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?

############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,

如下代码:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

输出如下错误信息:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正确:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

Variable 类源代码

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: 任意类型的封装好的张量。
    grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
    requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
    volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
    is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包装的张量.
    requires_grad (bool): bool型的标记值. **Keyword only.**
    volatile (bool): bool型的标记值. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """计算关于当前图叶子变量的梯度,图使用链式法则导致分化
    如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
    如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
    需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
    可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
    函数在叶子上累积梯度,调用前需要对该叶子进行清零。
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
              可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
      retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。
                      在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
                      默认值为create_graph的值。
      create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。
                      默认为False,除非``gradient``是一个volatile变量。
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None
    不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
 
    函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
    Parameters:
      reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """返回一个从当前图分离出来的心变量。
    结果不需要梯度,如果输入是volatile,则输出也是volatile。
 
    .. 注意::
     返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """从创建它的图中分离出变量并作为该图的一个叶子"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True

以上这篇Pytorch 中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python异步任务队列示例
Apr 01 Python
在Python中使用M2Crypto模块实现AES加密的教程
Apr 08 Python
Python实现模拟登录及表单提交的方法
Jul 25 Python
python生成词云的实现方法(推荐)
Jun 13 Python
老生常谈python之鸭子类和多态
Jun 13 Python
快速了解Python中的装饰器
Jan 11 Python
Python中使用logging和traceback模块记录日志和跟踪异常
Apr 09 Python
python实现最小二乘法线性拟合
Jul 19 Python
Python利用WMI实现ping命令的例子
Aug 14 Python
Python实现线性插值和三次样条插值的示例代码
Nov 13 Python
python获取引用对象的个数方式
Dec 20 Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
You might like
德生PL550的电路分析
2021/03/02 无线电
PHP的FTP学习(三)
2006/10/09 PHP
PHP 函数学习简单小结
2010/07/08 PHP
解析PHP工厂模式的好处
2013/06/18 PHP
针对thinkPHP5框架存储过程bug重写的存储过程扩展类完整实例
2018/06/16 PHP
php+js实现的无刷新下载文件功能示例
2019/08/23 PHP
会自动逐行上升的文本框
2006/06/30 Javascript
JQuery动画和停止动画实例代码
2013/03/01 Javascript
JS 实现获取打开一个界面中输入的值
2013/03/19 Javascript
JavaScript实现twitter puddles算法实例
2014/12/06 Javascript
JavaScript常用标签和方法总结
2015/09/01 Javascript
JavaScript基础教程——入门必看篇
2016/05/20 Javascript
AngularJS中关于ng-class指令的几种实现方式详解
2016/09/17 Javascript
JS实现旋转木马式图片轮播效果
2017/01/18 Javascript
jQuery点击页面其他部分隐藏下拉菜单功能
2018/11/27 jQuery
Vue-input框checkbox强制刷新问题
2019/04/18 Javascript
js计算两个时间差 天 时 分 秒 毫秒的代码
2019/05/21 Javascript
vue实现tab栏点击高亮效果
2020/08/19 Javascript
axios封装与传参示例详解
2020/10/18 Javascript
[02:22]2018DOTA2亚洲邀请赛VG赛前采访
2018/04/03 DOTA
python基于urllib实现按照百度音乐分类下载mp3的方法
2015/05/25 Python
Python实现简单的文件传输与MySQL备份的脚本分享
2016/01/03 Python
详解 Python中LEGB和闭包及装饰器
2017/08/03 Python
一条命令解决mac版本python IDLE不能输入中文问题
2018/05/15 Python
使用Django搭建一个基金模拟交易系统教程
2019/11/18 Python
Python 脚本的三种执行方式小结
2019/12/21 Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
2020/05/26 Python
python中有函数重载吗
2020/05/28 Python
使用CSS3编写灰阶滤镜来制作黑白照片效果的方法
2016/05/09 HTML / CSS
日常奢侈品,轻松购物:Verishop
2019/08/20 全球购物
2014年转正工作总结
2014/11/08 职场文书
2014年初三班主任工作总结
2014/12/05 职场文书
圣诞晚会主持词
2015/07/01 职场文书
2016年教师节特级教师获奖感言
2015/12/09 职场文书
创业分两种人:那么哪些适合创业?,哪些适合不适合创业呢?
2019/08/23 职场文书
移除Selenium中window.navigator.webdriver值
2022/06/10 Python