解决torch.autograd.backward中的参数问题


Posted in Python onJanuary 07, 2020

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)

给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可以通过链式法则求导。如果variables中的任何一个variable是 非标量(non-scalar)的,且requires_grad=True。那么此函数需要指定grad_variables,它的长度应该和variables的长度匹配,里面保存了相关variable的梯度(对于不需要gradient tensor的variable,None是可取的)。

此函数累积leaf variables计算的梯度。你可能需要在调用此函数之前将leaf variable的梯度置零。

参数:

variables(变量的序列) - 被求微分的叶子节点,即 ys 。

grad_variables((张量,变量)的序列或无) - 对应variable的梯度。仅当variable不是标量且需要求梯度的时候使用。

retain_graph(bool,可选) - 如果为False,则用于释放计算grad的图。请注意,在几乎所有情况下,没有必要将此选项设置为True,通常可以以更有效的方式解决。默认值为create_graph的值。

create_graph(bool,可选) - 如果为True,则将构造派生图,允许计算更高阶的派生产品。默认为False。

我这里举一个官方的例子

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()#这里是默认情况,相当于out.backward(torch.Tensor([1.0]))
print(x.grad)

输出结果是

Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]

解决torch.autograd.backward中的参数问题

接着我们继续

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
  y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

输出结果是

Variable containing:
 204.8000
 2048.0000
  0.2048
[torch.FloatTensor of size 3]

这里这个gradients为什么要是[0.1, 1.0, 0.0001]?

如果输出的多个loss权重不同的话,例如有三个loss,一个是x loss,一个是y loss,一个是class loss。那么很明显的不可能所有loss对结果影响程度都一样,他们之间应该有一个比例。那么比例这里指的就是[0.1, 1.0, 0.0001],这个问题中的loss对应的就是上面说的y,那么这里的输出就很好理解了dy/dx=0.1*dy1/dx+1.0*dy2/dx+0.0001*dy3/dx。

如有问题,希望大家指正,谢谢_!

以上这篇解决torch.autograd.backward中的参数问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
安装Python的教程-Windows
Jul 22 Python
tensorflow 恢复指定层与不同层指定不同学习率的方法
Jul 26 Python
详解python while 函数及while和for的区别
Sep 07 Python
解决安装pycharm后不能执行python脚本的问题
Jan 19 Python
Python Datetime模块和Calendar模块用法实例分析
Apr 15 Python
python字符串分割及字符串的一些常规方法
Jul 24 Python
python 遍历pd.Series的index和value
Nov 26 Python
pytorch 改变tensor尺寸的实现
Jan 03 Python
Python调用接口合并Excel表代码实例
Mar 31 Python
Django调用百度AI接口实现人脸注册登录代码实例
Apr 23 Python
在keras里面实现计算f1-score的代码
Jun 15 Python
python爬虫用scrapy获取影片的实例分析
Nov 23 Python
Pytorch 中retain_graph的用法详解
Jan 07 #Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
You might like
YB217、YB235、YB400浅听
2021/03/02 无线电
一个简单实现多条件查询的例子
2006/10/09 PHP
关于mysql 字段的那个点为是定界符
2007/01/15 PHP
php设置编码格式的方法
2013/03/05 PHP
php 根据URL下载远程图片、压缩包、pdf等文件到本地
2019/07/26 PHP
js window.event对象详尽解析
2009/02/17 Javascript
JavaScript的eval JSON object问题
2009/11/15 Javascript
40款非常有用的 jQuery 插件推荐(系列一)
2011/12/21 Javascript
P3P Header解决Cookie跨域的问题
2013/03/12 Javascript
深入理解javascript作用域和闭包
2014/09/23 Javascript
获取阴历(农历)和当前日期的js代码
2016/02/15 Javascript
JavaScript实现Base64编码转换
2016/04/23 Javascript
微信小程序的动画效果详解
2017/01/18 Javascript
vue中axios请求的封装实例代码
2019/03/23 Javascript
如何利用node.js开发一个生成逐帧动画的小工具
2019/12/01 Javascript
[01:10:24]DOTA2-DPC中国联赛 正赛 VG vs Aster BO3 第一场 2月28日
2021/03/11 DOTA
python OpenCV学习笔记之绘制直方图的方法
2018/02/08 Python
centos 安装python3.6环境并配置虚拟环境的详细教程
2018/02/22 Python
Redis使用watch完成秒杀抢购功能的代码
2018/05/07 Python
对python sklearn one-hot编码详解
2018/07/10 Python
PyCharm 创建指定版本的 Django(超详图解教程)
2019/06/18 Python
django-初始配置(纯手写)详解
2019/07/30 Python
Django admin禁用编辑链接和添加删除操作详解
2019/11/15 Python
flask的orm框架SQLAlchemy查询实现解析
2019/12/12 Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
2020/01/02 Python
python 简单的调用有道翻译
2020/11/25 Python
详解python日志输出使用配置文件格式
2021/02/10 Python
中国双语服务优势的在线购票及活动平台:247tickets
2018/10/26 全球购物
JDO的含义
2012/11/17 面试题
群胜软件Java笔试题
2012/09/29 面试题
某同学的自我鉴定范文
2013/12/26 职场文书
四好少年事迹材料
2014/01/12 职场文书
运动会口号8字
2014/06/07 职场文书
奥巴马经典演讲稿
2014/09/13 职场文书
公司禁烟通知
2015/04/23 职场文书
七年级写作指导之游记作文
2019/10/07 职场文书