解决torch.autograd.backward中的参数问题


Posted in Python onJanuary 07, 2020

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)

给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可以通过链式法则求导。如果variables中的任何一个variable是 非标量(non-scalar)的,且requires_grad=True。那么此函数需要指定grad_variables,它的长度应该和variables的长度匹配,里面保存了相关variable的梯度(对于不需要gradient tensor的variable,None是可取的)。

此函数累积leaf variables计算的梯度。你可能需要在调用此函数之前将leaf variable的梯度置零。

参数:

variables(变量的序列) - 被求微分的叶子节点,即 ys 。

grad_variables((张量,变量)的序列或无) - 对应variable的梯度。仅当variable不是标量且需要求梯度的时候使用。

retain_graph(bool,可选) - 如果为False,则用于释放计算grad的图。请注意,在几乎所有情况下,没有必要将此选项设置为True,通常可以以更有效的方式解决。默认值为create_graph的值。

create_graph(bool,可选) - 如果为True,则将构造派生图,允许计算更高阶的派生产品。默认为False。

我这里举一个官方的例子

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()#这里是默认情况,相当于out.backward(torch.Tensor([1.0]))
print(x.grad)

输出结果是

Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]

解决torch.autograd.backward中的参数问题

接着我们继续

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
  y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

输出结果是

Variable containing:
 204.8000
 2048.0000
  0.2048
[torch.FloatTensor of size 3]

这里这个gradients为什么要是[0.1, 1.0, 0.0001]?

如果输出的多个loss权重不同的话,例如有三个loss,一个是x loss,一个是y loss,一个是class loss。那么很明显的不可能所有loss对结果影响程度都一样,他们之间应该有一个比例。那么比例这里指的就是[0.1, 1.0, 0.0001],这个问题中的loss对应的就是上面说的y,那么这里的输出就很好理解了dy/dx=0.1*dy1/dx+1.0*dy2/dx+0.0001*dy3/dx。

如有问题,希望大家指正,谢谢_!

以上这篇解决torch.autograd.backward中的参数问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python cookielib 登录人人网的实现代码
Dec 19 Python
Tornado服务器中绑定域名、虚拟主机的方法
Aug 22 Python
利用Python中的mock库对Python代码进行模拟测试
Apr 16 Python
在Python的Flask框架中使用模版的入门教程
Apr 20 Python
Python实现大文件排序的方法
Jul 10 Python
12步入门Python中的decorator装饰器使用方法
Jun 20 Python
Django中间件基础用法详解
Jul 18 Python
Python使用python-docx读写word文档
Aug 26 Python
python 实现二维列表转置
Dec 02 Python
ansible动态Inventory主机清单配置遇到的坑
Jan 19 Python
python 解决mysql where in 对列表(list,,array)问题
Jun 06 Python
图解Python中深浅copy(通俗易懂)
Sep 03 Python
Pytorch 中retain_graph的用法详解
Jan 07 #Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
You might like
网页游戏开发入门教程三(简单程序应用)
2009/11/02 PHP
php几个预定义变量$_SERVER用法小结
2014/11/07 PHP
php通过文件流方式复制文件的方法
2015/03/13 PHP
基于php的微信公众平台开发入门实例
2015/04/15 PHP
php使用glob函数遍历文件和目录详解
2016/09/23 PHP
用window.location.href实现刷新另个框架页面
2007/03/07 Javascript
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
html+javascript实现可拖动可提交的弹出层对话框效果
2013/08/05 Javascript
js判断undefined类型,undefined,null, 的区别详细解析
2013/12/16 Javascript
Jquery的each里用return true或false代替break或continue
2014/05/21 Javascript
angularJS 中$attrs方法使用指南
2015/02/09 Javascript
AngularJS手动表单验证
2016/02/01 Javascript
一个字符串中出现次数最多的字符 统计这个次数【实现代码】
2016/04/29 Javascript
jQuery的实例及必知重要的jQuery选择器详解
2016/05/20 Javascript
微信 java 实现js-sdk 图片上传下载完整流程
2016/10/21 Javascript
基于javascript实现的快速排序
2016/12/02 Javascript
图片懒加载插件实例分享(含解析)
2017/01/09 Javascript
老生常谈angularjs中的$state.go
2017/04/24 Javascript
JavaScript基于扩展String实现替换字符串中index处字符的方法
2017/06/13 Javascript
在vue中获取dom元素内容的方法
2017/07/10 Javascript
Vue实现点击时间获取时间段查询功能
2020/08/21 Javascript
SVG实现时钟效果
2018/07/17 Javascript
koa+mongoose实现简单增删改查接口的示例代码
2019/05/13 Javascript
基于Python安装pyecharts所遇的问题及解决方法
2019/08/12 Python
python爬虫快速响应服务器的做法
2020/11/24 Python
购买中国最好的电子产品:Geekbuying
2018/03/13 全球购物
Ajax实现页面无刷新留言效果
2021/03/24 Javascript
人事主管的岗位职责
2013/11/16 职场文书
教育专业个人求职信
2013/12/02 职场文书
光荣入党自我鉴定
2014/01/22 职场文书
乡党委干部党的群众路线教育实践活动个人对照检查材料思想汇报
2014/10/01 职场文书
渠道运营商合作协议书范本
2014/10/06 职场文书
教师学习三严三实心得体会
2014/10/13 职场文书
党员对十八届四中全会的期盼思想汇报范文
2014/10/17 职场文书
如何写好闭幕词
2019/04/02 职场文书
python神经网络学习 使用Keras进行回归运算
2022/05/04 Python