浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字符串的方法与操作大全
Jan 30 Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 Python
Numpy中转置transpose、T和swapaxes的实例讲解
Apr 17 Python
对Python random模块打乱数组顺序的实例讲解
Nov 08 Python
python实现浪漫的烟花秀
Jan 30 Python
Django使用redis缓存服务器的实现代码示例
Apr 28 Python
Ubuntu下Anaconda和Pycharm配置方法详解
Jun 14 Python
在python3中实现查找数组中最接近与某值的元素操作
Feb 29 Python
Python Selenium 设置元素等待的三种方式
Mar 18 Python
详解Python 最短匹配模式
Jul 29 Python
详解Django ORM引发的数据库N+1性能问题
Oct 12 Python
python基于socket模拟实现ssh远程执行命令
Dec 05 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
php读取30天之内的根据算法排序的代码
2008/04/06 PHP
php+mysqli实现批量替换数据库表前缀的方法
2014/12/29 PHP
PHP实现移除数组中为空或为某值元素的方法
2017/01/07 PHP
php常用字符串查找函数strstr()与strpos()实例分析
2019/06/21 PHP
在你的网页中嵌入外部网页的方法
2007/04/02 Javascript
jquery中的on方法使用介绍
2013/12/29 Javascript
基于javascript、ajax、memcache和PHP实现的简易在线聊天室
2015/02/03 Javascript
Bootstrap表单布局样式源代码
2016/07/04 Javascript
微信小程序如何获取用户信息
2018/01/26 Javascript
Vue实现的父组件向子组件传值功能示例
2019/01/19 Javascript
JS中的函数与对象的创建方式
2019/05/12 Javascript
JS中超越现实的匿名函数用法实例分析
2019/06/21 Javascript
ES6常用小技巧总结【去重、交换、合并、反转、迭代、计算等】
2019/12/21 Javascript
JavaScript中变量提升和函数提升的详解
2020/08/07 Javascript
vue 动态添加class,三个以上的条件做判断方式
2020/11/02 Javascript
[39:02]DOTA2亚洲邀请赛 3.31 小组赛 B组 Mineski vs VGJ.T
2018/04/01 DOTA
[01:27:30]LGD vs Newbee 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/19 DOTA
python利用datetime模块计算时间差
2015/08/04 Python
Python3中使用PyMongo的方法详解
2017/07/28 Python
python实现K近邻回归,采用等权重和不等权重的方法
2019/01/23 Python
python可视化爬虫界面之天气查询
2019/07/03 Python
详解Python二维数组与三维数组切片的方法
2019/07/18 Python
Pytorch实现的手写数字mnist识别功能完整示例
2019/12/13 Python
python eventlet绿化和patch原理
2020/11/21 Python
墨西哥运动服饰和鞋网上商店:Netshoes墨西哥
2016/07/28 全球购物
Under Armour西班牙官网:美国知名的高端功能性运动品牌
2018/12/12 全球购物
西部世纪.net笔试题面试题
2014/04/03 面试题
公证委托书模板
2014/04/03 职场文书
党的群众路线教育实践活动个人批评与自我批评
2014/10/16 职场文书
党员示范岗材料
2014/12/19 职场文书
2016高考寄语集锦
2015/12/04 职场文书
《云雀的心愿》教学反思
2016/02/23 职场文书
导游词之井冈山
2019/11/20 职场文书
告别网页搜索!教你用python实现一款属于自己的翻译词典软件
2021/06/03 Python
Python合并pdf文件的工具
2021/07/01 Python
SpringBoot系列之MongoDB Aggregations用法详解
2022/02/12 MongoDB