解决torch.autograd.backward中的参数问题


Posted in Python onJanuary 07, 2020

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)

给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可以通过链式法则求导。如果variables中的任何一个variable是 非标量(non-scalar)的,且requires_grad=True。那么此函数需要指定grad_variables,它的长度应该和variables的长度匹配,里面保存了相关variable的梯度(对于不需要gradient tensor的variable,None是可取的)。

此函数累积leaf variables计算的梯度。你可能需要在调用此函数之前将leaf variable的梯度置零。

参数:

variables(变量的序列) - 被求微分的叶子节点,即 ys 。

grad_variables((张量,变量)的序列或无) - 对应variable的梯度。仅当variable不是标量且需要求梯度的时候使用。

retain_graph(bool,可选) - 如果为False,则用于释放计算grad的图。请注意,在几乎所有情况下,没有必要将此选项设置为True,通常可以以更有效的方式解决。默认值为create_graph的值。

create_graph(bool,可选) - 如果为True,则将构造派生图,允许计算更高阶的派生产品。默认为False。

我这里举一个官方的例子

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()#这里是默认情况,相当于out.backward(torch.Tensor([1.0]))
print(x.grad)

输出结果是

Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]

解决torch.autograd.backward中的参数问题

接着我们继续

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
  y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

输出结果是

Variable containing:
 204.8000
 2048.0000
  0.2048
[torch.FloatTensor of size 3]

这里这个gradients为什么要是[0.1, 1.0, 0.0001]?

如果输出的多个loss权重不同的话,例如有三个loss,一个是x loss,一个是y loss,一个是class loss。那么很明显的不可能所有loss对结果影响程度都一样,他们之间应该有一个比例。那么比例这里指的就是[0.1, 1.0, 0.0001],这个问题中的loss对应的就是上面说的y,那么这里的输出就很好理解了dy/dx=0.1*dy1/dx+1.0*dy2/dx+0.0001*dy3/dx。

如有问题,希望大家指正,谢谢_!

以上这篇解决torch.autograd.backward中的参数问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python str与repr的区别
Mar 23 Python
python模拟enum枚举类型的方法小结
Apr 30 Python
Python封装shell命令实例分析
May 05 Python
Python实现PS图像抽象画风效果的方法
Jan 23 Python
python编程嵌套函数实例代码
Feb 11 Python
APIStar:一个专为Python3设计的API框架
Sep 26 Python
Django中的ajax请求
Oct 19 Python
python中对_init_的理解及实例解析
Oct 11 Python
关于Python3 lambda函数的深入浅出
Nov 27 Python
解决tensorflow由于未初始化变量而导致的错误问题
Jan 06 Python
浅谈Django QuerySet对象(模型.objects)的常用方法
Mar 28 Python
python 实现压缩和解压缩的示例
Sep 22 Python
Pytorch 中retain_graph的用法详解
Jan 07 #Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
You might like
星际争霸 Starcraft 游戏介绍
2020/03/14 星际争霸
德生PL330测评
2021/03/02 无线电
随机广告显示(PHP函数)
2006/10/09 PHP
php实现无限级分类(递归方法)
2015/08/06 PHP
php遍历目录下文件并按修改时间排序操作示例
2019/07/12 PHP
一个选择最快的服务器转向代码
2009/04/27 Javascript
JavaScript中通过闭包解决只能取得包含函数中任何变量最后一个值的问题
2010/08/12 Javascript
jQuery EasyUI API 中文文档 - Parser 解析器
2011/09/29 Javascript
jquery validate poshytip 自定义样式
2012/11/26 Javascript
原生JS实现响应式瀑布流布局
2015/04/02 Javascript
基于JavaScript实现Json数据根据某个字段进行排序
2015/11/24 Javascript
详解基于Bootstrap扁平化的后台框架Ace
2015/11/27 Javascript
Javascript基础学习笔记(菜鸟必看篇)
2016/07/22 Javascript
原生JS实现垂直手风琴效果
2017/02/19 Javascript
jQuery实现判断上传图片类型和大小的方法示例
2018/04/11 jQuery
vue实现自定义多选与单选的答题功能
2018/07/05 Javascript
vue-cli的工程模板与构建工具详解
2018/09/27 Javascript
nodejs实现范围请求的实现代码
2018/10/12 NodeJs
require.js 加载过程与使用方法介绍
2018/10/30 Javascript
angular多语言配置详解
2019/05/16 Javascript
node.js使用net模块创建服务器和客户端示例【基于TCP协议】
2020/02/14 Javascript
Vue+Element-U实现分页显示效果
2020/11/15 Javascript
python33 urllib2使用方法细节讲解
2013/12/03 Python
详细解读Python中解析XML数据的方法
2015/10/15 Python
浅析python中while循环和for循环
2019/11/19 Python
基于python修改srt字幕的时间轴
2020/02/03 Python
Pyinstaller加密打包应用的示例代码
2020/06/11 Python
htnl5利用svg页面高斯模糊的方法
2018/07/20 HTML / CSS
Blank NYC官网:夹克、牛仔裤等
2020/12/16 全球购物
船舶专业个人求职信范文
2014/01/02 职场文书
会计专业职业规划:规划自我赢取未来
2014/02/12 职场文书
圣诞节活动策划方案
2014/06/09 职场文书
敬业奉献模范事迹材料(2016精选版)
2016/02/26 职场文书
2016年教育局“我们的节日——端午节”主题活动总结
2016/04/01 职场文书
申论不会写怎么办?教您掌握这6点思维和原则
2019/07/17 职场文书
JavaScript实现登录窗体
2021/06/22 Javascript