Pytorch 中retain_graph的用法详解


Posted in Python onJanuary 07, 2020

用法分析

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?

############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,

如下代码:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

输出如下错误信息:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正确:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

Variable 类源代码

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: 任意类型的封装好的张量。
    grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
    requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
    volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
    is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包装的张量.
    requires_grad (bool): bool型的标记值. **Keyword only.**
    volatile (bool): bool型的标记值. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """计算关于当前图叶子变量的梯度,图使用链式法则导致分化
    如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
    如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
    需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
    可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
    函数在叶子上累积梯度,调用前需要对该叶子进行清零。
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
              可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
      retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。
                      在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
                      默认值为create_graph的值。
      create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。
                      默认为False,除非``gradient``是一个volatile变量。
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None
    不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
 
    函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
    Parameters:
      reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """返回一个从当前图分离出来的心变量。
    结果不需要梯度,如果输入是volatile,则输出也是volatile。
 
    .. 注意::
     返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """从创建它的图中分离出变量并作为该图的一个叶子"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True

以上这篇Pytorch 中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
布同 Python中文问题解决方法(总结了多位前人经验,初学者必看)
Mar 13 Python
python实现的生成随机迷宫算法核心代码分享(含游戏完整代码)
Jul 11 Python
Python中的if、else、elif语句用法简明讲解
Mar 11 Python
python django使用haystack:全文检索的框架(实例讲解)
Sep 27 Python
Python数据处理numpy.median的实例讲解
Apr 02 Python
对Python3中的input函数详解
Apr 22 Python
Python socket套接字实现C/S模式远程命令执行功能案例
Jul 06 Python
Python3模拟curl发送post请求操作示例
May 03 Python
使用Python3 poplib模块删除服务器多天前的邮件实现代码
Apr 24 Python
python的数学算法函数及公式用法
Nov 18 Python
使用Python实现音频双通道分离
Dec 25 Python
Python使用OpenCV和K-Means聚类对毕业照进行图像分割
Jun 11 Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
You might like
我的论坛源代码(七)
2006/10/09 PHP
浅析PHP原理之变量(Variables inside PHP)
2013/08/09 PHP
php中session退出登陆问题
2014/02/27 PHP
CI框架实现框架前后端分离的方法详解
2016/12/30 PHP
PHP设计模式之工厂方法设计模式实例分析
2018/04/25 PHP
Laravel5.3+框架定义API路径取消CSRF保护方法详解
2020/04/06 PHP
firefox下jQuery UI Autocomplete 1.8.*中文输入修正方法
2012/09/19 Javascript
yarn与npm的命令行小结
2016/10/20 Javascript
基于jPlayer三分屏的制作方法
2016/12/21 Javascript
nodejs和php实现图片访问实时处理
2017/01/05 NodeJs
Node.js的特点详解
2017/02/03 Javascript
JavaScript基于数组实现的栈与队列操作示例
2018/12/22 Javascript
微信小程序实现的日期午别医生排班表功能示例
2019/01/09 Javascript
Vue项目前后端联调(使用proxyTable实现跨域方式)
2020/07/18 Javascript
Python设计模式之观察者模式简单示例
2018/01/10 Python
Python并行分布式框架Celery详解
2018/10/15 Python
Python创建字典的八种方式
2019/02/27 Python
浅谈Python爬虫基本套路
2019/03/25 Python
pytorch中tensor张量数据类型的转化方式
2019/12/31 Python
如何用Python绘制3D柱形图
2020/09/16 Python
PyCharm 2020.1版安装破解注册码永久激活(激活到2089年)
2020/09/24 Python
HTML5 input placeholder 颜色修改示例
2014/05/30 HTML / CSS
解锁canvas导出图片跨域的N种姿势小结
2019/01/24 HTML / CSS
小程序瀑布流解决左右两边高度差距过大的问题
2019/02/20 HTML / CSS
Theory美国官网:后现代都市风时装品牌
2018/05/09 全球购物
英国门销售网站:Green Tree Doors
2020/01/07 全球购物
女大学生个人求职信
2013/12/09 职场文书
编辑求职信样本
2013/12/16 职场文书
关于读书的演讲稿
2014/05/07 职场文书
营业用房租赁协议书
2014/11/26 职场文书
2015年妇幼卫生工作总结
2015/05/23 职场文书
大学生心理健康教育心得体会
2016/01/12 职场文书
企业管理制度设计时要注意的几种“常见病”!
2019/04/19 职场文书
用人单位的规章制度,怎样制定才是有效的?
2019/07/09 职场文书
请学会珍惜眼前,因为人生没有下辈子!
2019/11/12 职场文书
Python快速优雅的批量修改Word文档样式
2021/05/20 Python