浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Socket编程入门教程
Jul 11 Python
Python中MySQL数据迁移到MongoDB脚本的方法
Apr 28 Python
ubuntu中配置pyqt4环境教程
Dec 27 Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 Python
python实现kmp算法的实例代码
Apr 03 Python
python障碍式期权定价公式
Jul 19 Python
如何使用Flask-Migrate拓展数据库表结构
Jul 24 Python
django中使用事务及接入支付宝支付功能
Sep 15 Python
利用python实现PSO算法优化二元函数
Nov 13 Python
解决pyCharm中 module 调用失败的问题
Feb 12 Python
python中使用you-get库批量在线下载bilibili视频的教程
Mar 10 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
关于拼配咖啡,你要知道
2021/03/03 咖啡文化
thinkphp模板继承实例简述
2014/11/26 PHP
PHP实现多文件上传的方法
2015/07/08 PHP
Tab页界面 用jQuery及Ajax技术实现(php后台)
2011/10/12 Javascript
JavaScript对Json的增删改属性详解
2016/06/02 Javascript
在DWR中实现直接获取一个JAVA类的返回值的两种方法
2016/12/25 Javascript
JS使用正则截取两个字符串之间的字符串实现方法详解
2017/01/06 Javascript
JavaScript数组去重的6个方法
2017/01/21 Javascript
JS表格组件神器bootstrap table使用指南详解
2017/04/12 Javascript
什么是Vue.js框架 为什么选择它?
2017/10/17 Javascript
jQuery实现点击旋转,再点击恢复初始状态动画效果示例
2018/12/11 jQuery
node.js微信小程序配置消息推送的实现
2019/02/13 Javascript
vue组件系列之TagsInput详解
2020/05/14 Javascript
JavaScript前后端JSON使用方法教程
2020/11/23 Javascript
Python多线程学习资料
2012/12/19 Python
浅析Python中的多进程与多线程的使用
2015/04/07 Python
python利用urllib实现爬取京东网站商品图片的爬虫实例
2017/08/24 Python
python中numpy.zeros(np.zeros)的使用方法
2017/11/07 Python
对Python3.6 IDLE常用快捷键介绍
2018/07/16 Python
Python 运行 shell 获取输出结果的实例
2019/01/07 Python
Python寻找路径和查找文件路径的示例
2019/07/10 Python
深入了解Python iter() 方法的用法
2019/07/11 Python
对python 树状嵌套结构的实现思路详解
2019/08/09 Python
关于Python3 类方法、静态方法新解
2019/08/30 Python
HTML5之HTML元素扩展(下)—增强的Form表单元素值得关注
2013/01/31 HTML / CSS
清除canvas画布内容(点擦除+线擦除)
2020/08/12 HTML / CSS
日本卡普空电视游戏软件公司官方购物网站:e-CAPCOM
2018/07/17 全球购物
护士自我鉴定怎么写
2014/02/07 职场文书
2014年最新领导班子整改方案
2014/09/27 职场文书
2015年世界水日活动总结
2015/02/09 职场文书
2015年教学工作总结
2015/04/02 职场文书
早安问候语大全
2015/11/10 职场文书
自己搭建resnet18网络并加载torchvision自带权重的操作
2021/05/13 Python
如何理解PHP核心特性命名空间
2021/05/28 PHP
Python 可迭代对象 iterable的具体使用
2021/08/07 Python
C#连接ORACLE出现乱码问题的解决方法
2021/10/05 Oracle