解决torch.autograd.backward中的参数问题


Posted in Python onJanuary 07, 2020

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)

给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可以通过链式法则求导。如果variables中的任何一个variable是 非标量(non-scalar)的,且requires_grad=True。那么此函数需要指定grad_variables,它的长度应该和variables的长度匹配,里面保存了相关variable的梯度(对于不需要gradient tensor的variable,None是可取的)。

此函数累积leaf variables计算的梯度。你可能需要在调用此函数之前将leaf variable的梯度置零。

参数:

variables(变量的序列) - 被求微分的叶子节点,即 ys 。

grad_variables((张量,变量)的序列或无) - 对应variable的梯度。仅当variable不是标量且需要求梯度的时候使用。

retain_graph(bool,可选) - 如果为False,则用于释放计算grad的图。请注意,在几乎所有情况下,没有必要将此选项设置为True,通常可以以更有效的方式解决。默认值为create_graph的值。

create_graph(bool,可选) - 如果为True,则将构造派生图,允许计算更高阶的派生产品。默认为False。

我这里举一个官方的例子

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()#这里是默认情况,相当于out.backward(torch.Tensor([1.0]))
print(x.grad)

输出结果是

Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]

解决torch.autograd.backward中的参数问题

接着我们继续

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
  y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

输出结果是

Variable containing:
 204.8000
 2048.0000
  0.2048
[torch.FloatTensor of size 3]

这里这个gradients为什么要是[0.1, 1.0, 0.0001]?

如果输出的多个loss权重不同的话,例如有三个loss,一个是x loss,一个是y loss,一个是class loss。那么很明显的不可能所有loss对结果影响程度都一样,他们之间应该有一个比例。那么比例这里指的就是[0.1, 1.0, 0.0001],这个问题中的loss对应的就是上面说的y,那么这里的输出就很好理解了dy/dx=0.1*dy1/dx+1.0*dy2/dx+0.0001*dy3/dx。

如有问题,希望大家指正,谢谢_!

以上这篇解决torch.autograd.backward中的参数问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门_条件控制(详解)
May 16 Python
python 实现UTC时间加减的方法
Dec 31 Python
python控制nao机器人身体动作实例详解
Apr 29 Python
python 判断字符串中是否含有汉字或非汉字的实例
Jul 15 Python
python 列表推导式使用详解
Aug 29 Python
Django之PopUp的具体实现方法
Aug 31 Python
解决Pycharm 导入其他文件夹源码的2种方法
Feb 12 Python
解决Django中checkbox复选框的传值问题
Mar 31 Python
使用matplotlib动态刷新指定曲线实例
Apr 23 Python
Python基于Socket实现简易多人聊天室的示例代码
Nov 29 Python
Python实现区域填充的示例代码
Feb 03 Python
能让Python提速超40倍的神器Cython详解
Jun 24 Python
Pytorch 中retain_graph的用法详解
Jan 07 #Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
You might like
PHP无限分类(树形类)
2013/09/28 PHP
5种PHP创建数组的实例代码分享
2014/01/17 PHP
php利用事务处理转账问题
2015/04/22 PHP
Laravel网站打开速度优化的方法汇总
2017/07/16 PHP
PHP实现Markdown文章上传到七牛图床的实例内容
2020/02/11 PHP
jQuery源码分析-04 选择器-Sizzle-工作原理分析
2011/11/14 Javascript
超棒的响应式布局jQuery插件Freetile.js
2014/11/17 Javascript
JavaScript中的原型链prototype介绍
2014/12/30 Javascript
javascript动态添加删除tabs标签的方法
2015/07/06 Javascript
浅谈javascript原型链与继承
2015/07/13 Javascript
js实现基于正则表达式的轻量提示插件
2015/08/29 Javascript
微信小程序开发之视频播放器 Video 弹幕 弹幕颜色自定义实例
2016/12/08 Javascript
简单实现jQuery级联菜单
2017/01/09 Javascript
JS图片延迟加载插件LazyImgv1.0用法分析【附demo源码下载】
2017/09/04 Javascript
JavaScript获取移动设备型号的实现代码(JS获取手机型号和系统)
2018/03/10 Javascript
20个最常见的jQuery面试问题及答案
2018/05/23 jQuery
ztree加载完成后显示勾选节点的实现代码
2018/10/22 Javascript
vue element-ui table组件动态生成表头和数据并修改单元格格式 父子组件通信
2019/08/15 Javascript
layer更改皮肤的实现方法
2019/09/11 Javascript
vue项目,代码提交至码云,iconfont的用法说明
2020/07/30 Javascript
Python datetime时间格式化去掉前导0
2014/07/31 Python
在Django中编写模版节点及注册标签的方法
2015/07/20 Python
Django之PopUp的具体实现方法
2019/08/31 Python
Python: tkinter窗口屏幕居中,设置窗口最大,最小尺寸实例
2020/03/04 Python
Python venv虚拟环境配置过程解析
2020/07/08 Python
详解使用canvas保存网页为pdf文件支持跨域
2018/11/23 HTML / CSS
荷兰音乐会和音乐剧门票订购网站:Topticketshop
2019/08/27 全球购物
应届生幼儿园求职信
2013/11/12 职场文书
行政人员工作职责
2013/12/05 职场文书
校园新闻广播稿
2014/01/10 职场文书
敬老院献爱心活动总结
2014/07/08 职场文书
2014年数学教师工作总结
2014/12/03 职场文书
升学宴学生答谢词
2015/01/05 职场文书
会计人员岗位职责
2015/02/03 职场文书
地道战观后感500字
2015/06/04 职场文书
Nginx本地目录映射实现代码实例
2021/03/31 Servers