解决torch.autograd.backward中的参数问题


Posted in Python onJanuary 07, 2020

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)

给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可以通过链式法则求导。如果variables中的任何一个variable是 非标量(non-scalar)的,且requires_grad=True。那么此函数需要指定grad_variables,它的长度应该和variables的长度匹配,里面保存了相关variable的梯度(对于不需要gradient tensor的variable,None是可取的)。

此函数累积leaf variables计算的梯度。你可能需要在调用此函数之前将leaf variable的梯度置零。

参数:

variables(变量的序列) - 被求微分的叶子节点,即 ys 。

grad_variables((张量,变量)的序列或无) - 对应variable的梯度。仅当variable不是标量且需要求梯度的时候使用。

retain_graph(bool,可选) - 如果为False,则用于释放计算grad的图。请注意,在几乎所有情况下,没有必要将此选项设置为True,通常可以以更有效的方式解决。默认值为create_graph的值。

create_graph(bool,可选) - 如果为True,则将构造派生图,允许计算更高阶的派生产品。默认为False。

我这里举一个官方的例子

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()#这里是默认情况,相当于out.backward(torch.Tensor([1.0]))
print(x.grad)

输出结果是

Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]

解决torch.autograd.backward中的参数问题

接着我们继续

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
  y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

输出结果是

Variable containing:
 204.8000
 2048.0000
  0.2048
[torch.FloatTensor of size 3]

这里这个gradients为什么要是[0.1, 1.0, 0.0001]?

如果输出的多个loss权重不同的话,例如有三个loss,一个是x loss,一个是y loss,一个是class loss。那么很明显的不可能所有loss对结果影响程度都一样,他们之间应该有一个比例。那么比例这里指的就是[0.1, 1.0, 0.0001],这个问题中的loss对应的就是上面说的y,那么这里的输出就很好理解了dy/dx=0.1*dy1/dx+1.0*dy2/dx+0.0001*dy3/dx。

如有问题,希望大家指正,谢谢_!

以上这篇解决torch.autograd.backward中的参数问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的简单dns查询功能示例
May 24 Python
python监控文件并且发送告警邮件
Jun 21 Python
python实现维吉尼亚加密法
Mar 20 Python
python中正则表达式与模式匹配
May 07 Python
Python自定义函数计算给定日期是该年第几天的方法示例
May 30 Python
利用selenium爬虫抓取数据的基础教程
Jun 10 Python
linux中如何使用python3获取ip地址
Jul 15 Python
使用python+whoosh实现全文检索
Dec 09 Python
Python爬虫Scrapy框架CrawlSpider原理及使用案例
Nov 20 Python
在python中实现导入一个需要传参的模块
May 12 Python
Python多个MP4合成视频的实现方法
Jul 16 Python
Python内置的数据类型及使用方法
Apr 13 Python
Pytorch 中retain_graph的用法详解
Jan 07 #Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
You might like
php使用ob_flush不能每隔一秒输出原理分析
2015/06/02 PHP
彻底搞懂PHP 变量结构体
2017/10/11 PHP
Laravel 对某一列进行筛选然后求和sum()的例子
2019/10/10 PHP
不要小看注释掉的JS 引起的安全问题
2008/12/27 Javascript
juqery 学习之五 文档处理 插入
2011/02/11 Javascript
JS截取字符串常用方法整理及使用示例
2013/10/18 Javascript
ES6所改良的javascript“缺陷”问题
2016/08/23 Javascript
JavaScript 函数模式详解及示例
2016/09/07 Javascript
JavaScript编写九九乘法表(两种任选)
2017/02/04 Javascript
javascript 组合按键事件监听实现代码
2017/02/21 Javascript
JSON与JS对象的区别与对比
2017/03/01 Javascript
将angular-ui的分页组件封装成指令的方法详解
2017/05/10 Javascript
JavaScript对象_动力节点Java学院整理
2017/06/23 Javascript
js仿微信抢红包功能
2020/09/25 Javascript
nodejs实现简单的gulp打包
2017/12/21 NodeJs
vue+iview+less 实现换肤功能
2018/08/17 Javascript
python虚拟环境 virtualenv的简单使用
2020/01/21 Javascript
原生javascript中this几种常见用法总结
2020/02/24 Javascript
[57:47]Fnatic vs Winstrike 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
Python中Continue语句的用法的举例详解
2015/05/14 Python
python使用PyGame模块播放声音的方法
2015/05/20 Python
python: line=f.readlines()消除line中\n的方法
2018/03/19 Python
Django 查询数据库并返回页面的例子
2019/08/12 Python
django 配置阿里云OSS存储media文件的例子
2019/08/20 Python
python修改FTP服务器上的文件名
2019/09/11 Python
win10安装tensorflow-gpu1.8.0详细完整步骤
2020/01/20 Python
tensorflow查看ckpt各节点名称实例
2020/01/21 Python
python 中的[:-1]和[::-1]的具体使用
2020/02/13 Python
vscode配置anaconda3的方法步骤
2020/08/08 Python
美国最灵活的移动提供商:Tello
2017/07/18 全球购物
美味咖啡的顶级烘焙师:Cafe Britt
2018/03/15 全球购物
银行竞聘演讲稿范文
2014/04/23 职场文书
给校长的建议书500字
2014/05/15 职场文书
劳动竞赛口号
2014/06/16 职场文书
2014年除四害工作总结
2014/12/06 职场文书
SQLServer2019 数据库的基本使用之图形化界面操作的实现
2021/04/08 SQL Server