浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序员鲜为人知但你应该知道的17个问题
Jun 04 Python
跟老齐学Python之一个免费的实验室
Sep 14 Python
python对字典进行排序实例
Sep 25 Python
从Python的源码来解析Python下的freeblock
May 11 Python
为什么你还不懂得怎么使用Python协程
May 13 Python
django项目中使用手机号登录的实例代码
Aug 15 Python
python nmap实现端口扫描器教程
May 28 Python
通过实例解析python描述符原理作用
Jan 22 Python
jenkins+python自动化测试持续集成教程
May 12 Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 Python
Python+Dlib+Opencv实现人脸采集并表情判别功能的代码
Jul 01 Python
keras:model.compile损失函数的用法
Jul 01 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
Laravel 批量更新多条数据的示例
2017/11/27 PHP
关于jQuery参考实例 1.0 jQuery的哲学
2013/04/07 Javascript
JS调用CS里的带参方法实例
2013/08/01 Javascript
JS和函数式语言的三特性
2014/03/05 Javascript
基于Jquery实现键盘按键监听
2014/05/11 Javascript
浅谈javascript 函数属性和方法
2015/01/21 Javascript
js计算德州扑克牌面值的方法
2015/03/04 Javascript
jQuery实现瀑布流布局详解(PC和移动端)
2020/09/01 Javascript
JavaScript Length 属性的总结
2015/11/02 Javascript
JavaScrip常见的一些算法总结
2015/12/28 Javascript
基于JavaScript判断浏览器到底是关闭还是刷新(超准确)
2016/02/01 Javascript
概述jQuery的元素筛选
2016/11/23 Javascript
jQuery内容筛选选择器实例代码
2017/02/06 Javascript
用nodeJS搭建本地文件服务器的几种方法小结
2017/03/16 NodeJs
JavaScript中 DOM操作方法小结
2017/04/25 Javascript
Angularjs的键盘事件的绑定
2017/07/27 Javascript
获取本机IP地址的实例(JavaScript / Node.js)
2017/11/24 Javascript
详解redis在nodejs中的应用
2018/05/02 NodeJs
微信小程序 wxParse插件显示视频问题
2019/09/27 Javascript
微信小程序静默登录的实现代码
2020/01/08 Javascript
微信小程序清空输入框信息与实现屏幕往上滚动的示例代码
2020/06/23 Javascript
在vue中使用inheritAttrs实现组件的扩展性介绍
2020/12/07 Vue.js
Python 数据处理库 pandas 入门教程基本操作
2018/04/19 Python
python+pygame实现坦克大战小游戏的示例代码(可以自定义子弹速度)
2020/08/11 Python
css3个性化字体_动力节点Java学院整理
2017/07/12 HTML / CSS
美国时尚大码女装购物网站:Avenue
2019/05/24 全球购物
入党积极分子学习两会心得体会范文
2014/03/17 职场文书
户籍证明格式
2014/09/15 职场文书
校园运动会广播稿
2015/08/19 职场文书
三好学生竞选稿
2015/11/21 职场文书
志愿者工作心得体会
2016/01/15 职场文书
如何计划开一家便利店?
2019/07/31 职场文书
HTML速写之Emmet语法规则的实现
2021/04/07 HTML / CSS
原生JS中应该禁止出现的写法
2021/05/05 Javascript
使用Redis实现秒杀功能的简单方法
2021/05/08 Redis
nginx配置文件使用环境变量的操作方法
2021/06/02 Servers