解决torch.autograd.backward中的参数问题


Posted in Python onJanuary 07, 2020

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)

给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可以通过链式法则求导。如果variables中的任何一个variable是 非标量(non-scalar)的,且requires_grad=True。那么此函数需要指定grad_variables,它的长度应该和variables的长度匹配,里面保存了相关variable的梯度(对于不需要gradient tensor的variable,None是可取的)。

此函数累积leaf variables计算的梯度。你可能需要在调用此函数之前将leaf variable的梯度置零。

参数:

variables(变量的序列) - 被求微分的叶子节点,即 ys 。

grad_variables((张量,变量)的序列或无) - 对应variable的梯度。仅当variable不是标量且需要求梯度的时候使用。

retain_graph(bool,可选) - 如果为False,则用于释放计算grad的图。请注意,在几乎所有情况下,没有必要将此选项设置为True,通常可以以更有效的方式解决。默认值为create_graph的值。

create_graph(bool,可选) - 如果为True,则将构造派生图,允许计算更高阶的派生产品。默认为False。

我这里举一个官方的例子

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()#这里是默认情况,相当于out.backward(torch.Tensor([1.0]))
print(x.grad)

输出结果是

Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]

解决torch.autograd.backward中的参数问题

接着我们继续

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
  y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

输出结果是

Variable containing:
 204.8000
 2048.0000
  0.2048
[torch.FloatTensor of size 3]

这里这个gradients为什么要是[0.1, 1.0, 0.0001]?

如果输出的多个loss权重不同的话,例如有三个loss,一个是x loss,一个是y loss,一个是class loss。那么很明显的不可能所有loss对结果影响程度都一样,他们之间应该有一个比例。那么比例这里指的就是[0.1, 1.0, 0.0001],这个问题中的loss对应的就是上面说的y,那么这里的输出就很好理解了dy/dx=0.1*dy1/dx+1.0*dy2/dx+0.0001*dy3/dx。

如有问题,希望大家指正,谢谢_!

以上这篇解决torch.autograd.backward中的参数问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3基础之输入和输出实例分析
Aug 18 Python
python使用chardet判断字符串编码的方法
Mar 13 Python
Python yield 使用浅析
May 28 Python
深入理解Python中装饰器的用法
Jun 28 Python
Django使用HttpResponse返回图片并显示的方法
May 22 Python
详解python使用pip安装第三方库(工具包)速度慢、超时、失败的解决方案
Dec 02 Python
python实现AES加密和解密
Mar 27 Python
Django实现微信小程序的登录验证功能并维护登录态
Jul 04 Python
python爬虫 爬取超清壁纸代码实例
Aug 16 Python
Python判断字符串是否为空和null方法实例
Apr 26 Python
python使用正则表达式匹配txt特定字符串(有换行)
Dec 09 Python
python识别围棋定位棋盘位置
Jul 26 Python
Pytorch 中retain_graph的用法详解
Jan 07 #Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
You might like
PHP 内存缓存加速功能memcached安装与用法
2009/09/03 PHP
php 获取本地IP代码
2013/06/23 PHP
php使用array_rand()函数从数组中随机选择一个或多个元素
2014/04/28 PHP
PHP 将数组打乱 shuffle函数的用法及简单实例
2016/06/17 PHP
js插件方式打开pdf文件(浏览器pdf插件分享)
2013/12/20 Javascript
JS常见问题之为什么点击弹出的i总是最后一个
2016/01/05 Javascript
推荐VSCode 上特别好用的 Vue 插件之vetur
2017/09/14 Javascript
vue用addRoutes实现动态路由的示例
2017/09/15 Javascript
vue.js给动态绑定的radio列表做批量编辑的方法
2018/02/28 Javascript
如何使用electron-builder及electron-updater给项目配置自动更新
2018/12/24 Javascript
实例讲解JS中pop使用方法
2019/01/27 Javascript
深入学习JavaScript 高阶函数
2019/06/11 Javascript
JavaScript Tab菜单实现过程解析
2020/05/13 Javascript
Node.js API详解之 dns模块用法实例分析
2020/05/15 Javascript
微信小程序自定义弹出层效果
2020/05/26 Javascript
使用js原生实现年份轮播选择效果实例
2021/01/12 Javascript
[03:09]DOTA2亚洲邀请赛 LGD战队出场宣传片
2015/02/07 DOTA
Python  unittest单元测试框架的使用
2018/09/08 Python
Python 自动登录淘宝并保存登录信息的方法
2019/09/04 Python
django实现支付宝支付实例讲解
2019/10/17 Python
Python开发之pip安装及使用方法详解
2020/02/21 Python
在Ubuntu 20.04中安装Pycharm 2020.1的图文教程
2020/04/30 Python
css3背景图片透明叠加属性cross-fade简介及用法实例
2013/01/08 HTML / CSS
两种CSS3伪类选择器详细介绍
2013/12/24 HTML / CSS
使用CSS3来绘制一个月食图案
2015/07/18 HTML / CSS
Origins加拿大官网:雅诗兰黛集团高端植物护肤品牌
2017/11/19 全球购物
Vans奥地利官方网站:美国原创极限运动潮牌
2018/09/30 全球购物
求职推荐信
2013/10/28 职场文书
教师应聘自荐信范文
2014/03/14 职场文书
团党委领导干部党的群众路线教育实践活动个人对照检查材料思想汇
2014/10/05 职场文书
保送生自荐信
2015/03/06 职场文书
大学升旗仪式主持词
2015/07/04 职场文书
商业计划书如何写?关键问题有哪些?
2019/07/11 职场文书
Python竟然能剪辑视频
2021/05/25 Python
常用的Python代码调试工具总结
2021/06/23 Python
SpringBoot中获取profile的方法详解
2022/04/08 Java/Android