浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过yield实现数组全排列的方法
Mar 18 Python
一步步解析Python斗牛游戏的概率
Feb 12 Python
使用Python的Tornado框架实现一个Web端图书展示页面
Jul 11 Python
python中快速进行多个字符替换的方法小结
Dec 15 Python
python实现BackPropagation算法
Dec 14 Python
python实现用户答题功能
Jan 17 Python
Python工厂函数用法实例分析
May 14 Python
Django中Middleware中的函数详解
Jul 18 Python
Python学习笔记之列表和成员运算符及列表相关方法详解
Aug 22 Python
python数据爬下来保存的位置
Feb 17 Python
pandas使用之宽表变窄表的实现
Apr 12 Python
Pytorch GPU内存占用很高,但是利用率很低如何解决
Jun 01 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
PHP 加密与解密的斗争
2009/04/17 PHP
php学习笔记之 函数声明(二)
2011/06/09 PHP
ThinkPHP项目分组配置方法分析
2016/03/23 PHP
浅谈PHP中类和对象的相关函数
2017/04/26 PHP
50个优秀经典PHP算法大集合 附源码
2020/08/26 PHP
PHP数组实际占用内存大小原理解析
2020/12/11 PHP
JQuery autocomplete 使用手册
2010/04/01 Javascript
js中更短的 Array 类型转换
2011/10/30 Javascript
javascript截取字符串(通过substring实现并支持中英文混合)
2013/06/24 Javascript
js获取select标签的值且兼容IE与firefox
2013/12/30 Javascript
浅谈JavaScript中的Math.atan()方法的使用
2015/06/14 Javascript
JSP基于Bootstrap分页显示实例解析
2016/06/12 Javascript
判断输入的字符串是否是日期格式的简单方法
2016/07/11 Javascript
详解vue中使用express+fetch获取本地json文件
2017/10/10 Javascript
js使用swiper实现层叠轮播效果实例代码
2018/12/12 Javascript
vue进入页面时滚动条始终在底部代码实例
2019/03/26 Javascript
node获取客户端ip功能简单示例
2019/08/24 Javascript
jQuery鼠标滑过横向时间轴样式(代码详解)
2019/11/01 jQuery
nodejs如何在package.json中设置多条启动命令
2020/03/16 NodeJs
vue渲染方式render和template的区别
2020/06/05 Javascript
[57:38]2018DOTA2亚洲邀请赛3月30日 小组赛A组 OpTic VS OG
2018/03/31 DOTA
[56:46]2018DOTA2亚洲邀请赛 3.31 小组赛 B组 VP vs Effect
2018/04/01 DOTA
python妹子图简单爬虫实例
2015/07/07 Python
python之文件的读写和文件目录以及文件夹的操作实现代码
2016/08/28 Python
在Mac上删除自己安装的Python方法
2018/10/29 Python
python抓取京东小米8手机配置信息
2018/11/13 Python
django+mysql的使用示例
2018/11/23 Python
python 用户交互输入input的4种用法详解
2019/09/24 Python
HTML5+Canvas+CSS3实现齐天大圣孙悟空腾云驾雾效果
2016/04/26 HTML / CSS
土耳其时尚购物网站:Morhipo
2017/09/04 全球购物
英国最好的温室之家:Greenhouses Direct
2019/07/13 全球购物
波兰最大的电商平台:Allegro.pl
2021/02/06 全球购物
上课迟到检讨书100字
2014/01/11 职场文书
高中考试作弊检讨书
2014/01/14 职场文书
单身联谊活动方案
2014/01/29 职场文书