在pytorch中对非叶节点的变量计算梯度实例


Posted in Python onJanuary 10, 2020

在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行。

在pytorch中对非叶节点的变量计算梯度实例

注册hook函数

Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在每次计算 关于该张量 在pytorch中对非叶节点的变量计算梯度实例 的时候 被调用,经常用于调试的时候打印出非叶节点梯度。当然,通过这个手段,你也可以自定义某一层的梯度更新方法。[3] 具体到这里的打印非叶节点的梯度,代码如:

def hook_y(grad):
 print(grad)

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y) 

out = z.mean()
out.backward()

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

retain_grad()

Tensor.retain_grad()显式地保存非叶节点的梯度,当然代价就是会增加显存的消耗,而用hook函数的方法则是在反向计算时直接打印,因此不会增加显存消耗,但是使用起来retain_grad()要比hook函数方便一些。代码如:

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

以上这篇在pytorch中对非叶节点的变量计算梯度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
几个提升Python运行效率的方法之间的对比
Apr 03 Python
python使用matplotlib绘图时图例显示问题的解决
Apr 27 Python
Python 3中print函数的使用方法总结
Aug 08 Python
Python实现的矩阵类实例
Aug 22 Python
python 中if else 语句的作用及示例代码
Mar 05 Python
Python continue继续循环用法总结
Jun 10 Python
详解django.contirb.auth-认证
Jul 16 Python
python中将\\uxxxx转换为Unicode字符串的方法
Sep 06 Python
Python 访问限制 private public的详细介绍
Oct 16 Python
给大家整理了19个pythonic的编程习惯(小结)
Sep 25 Python
pandas中遍历dataframe的每一个元素的实现
Oct 23 Python
Python map及filter函数使用方法解析
Aug 06 Python
python如何获取apk的packagename和activity
Jan 10 #Python
浅谈pytorch卷积核大小的设置对全连接神经元的影响
Jan 10 #Python
python颜色随机生成器的实例代码
Jan 10 #Python
关于python pycharm中输出的内容不全的解决办法
Jan 10 #Python
Python GUI自动化实现绕过验证码登录
Jan 10 #Python
pytorch nn.Conv2d()中的padding以及输出大小方式
Jan 10 #Python
如何给Python代码进行加密
Jan 10 #Python
You might like
PHP 调试工具Debug Tools
2011/04/30 PHP
php 数组使用详解 推荐
2011/06/02 PHP
php引用返回与取消引用的详解
2013/06/08 PHP
新浪SAE搭建PHP项目教程
2015/01/28 PHP
详解PHP中websocket的使用方法
2016/09/15 PHP
thinkphp5引入公共部分header、footer的方法详解
2018/09/14 PHP
jQuery Ajax使用 全解析
2010/12/15 Javascript
推荐30个新鲜出炉的精美 jQuery 效果
2012/03/26 Javascript
window.requestAnimationFrame是什么意思,怎么用
2013/01/13 Javascript
jquery实现滑动图片自己测试的例子
2013/11/05 Javascript
ExtJS4 动态生成的grid导出为excel示例
2014/05/02 Javascript
JavaScript设计模式之外观模式介绍
2014/12/28 Javascript
Jquery对select的增、删、改、查操作
2015/02/06 Javascript
Bootstrap每天必学之进度条
2015/11/30 Javascript
初学者AngularJS的环境搭建过程
2017/10/27 Javascript
常用的 JS 排序算法 整理版
2018/04/05 Javascript
JS使用数组实现的队列功能示例
2019/03/04 Javascript
通过实例解析JavaScript for in及for of区别
2020/06/15 Javascript
vuex实现购物车的增加减少移除
2020/06/28 Javascript
微信小程序连接服务器展示MQTT数据信息的实现
2020/07/14 Javascript
vue+elementui实现点击table中的单元格触发事件--弹框
2020/07/18 Javascript
vue中使用腾讯云Im的示例
2020/10/23 Javascript
微信小程序视频弹幕发送功能的实现
2020/12/28 Javascript
[08:38]DOTA2-DPC中国联赛 正赛 VG vs Elephant 选手采访
2021/03/11 DOTA
python 随机数生成的代码的详细分析
2011/05/15 Python
python 将字符串转换成字典dict
2013/03/24 Python
详解django2中关于时间处理策略
2019/03/06 Python
python实现信号时域统计特征提取代码
2020/02/26 Python
美国派对用品及装饰品网上商店:Shindigz
2016/07/30 全球购物
Maisons du Monde德国:法国家具和装饰的市场领导者
2019/07/26 全球购物
Elizabeth Gage官网:英国最好的珠宝设计之一
2020/09/26 全球购物
毕业设计计划书
2014/01/09 职场文书
保健品市场营销方案
2014/03/31 职场文书
人工作失职检讨书
2015/05/05 职场文书
2015年测量员工作总结
2015/05/23 职场文书
Python TypeError: ‘float‘ object is not subscriptable错误解决
2022/12/24 Python