浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理文本文件并生成指定格式的文件
Jul 31 Python
Python制作Windows系统服务
Mar 25 Python
Python实现导出数据生成excel报表的方法示例
Jul 12 Python
Python用户推荐系统曼哈顿算法实现完整代码
Dec 01 Python
numpy数组拼接简单示例
Dec 15 Python
利用python实现微信头像加红色数字功能
Mar 26 Python
Python中循环引用(import)失败的解决方法
Apr 22 Python
python使用matplotlib画饼状图
Sep 25 Python
在Python中字典根据多项规则排序的方法
Jan 21 Python
PyQt5 界面显示无响应的实现
Mar 26 Python
python手机号前7位归属地爬虫代码实例
Mar 31 Python
python迷宫问题深度优先遍历实例
Jun 20 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
php 魔术常量详解及实例代码
2016/12/04 PHP
php中文乱码问题的终极解决方案汇总
2017/08/01 PHP
thinkPHP框架实现的简单计算器示例
2018/12/07 PHP
javascript showModalDialog,open取得父窗口的方法
2010/03/10 Javascript
理解Javascript_09_Function与Object
2010/10/16 Javascript
中国地区三级联动下拉菜单效果分析
2012/11/15 Javascript
javascript的动态加载、缓存、更新以及复用(一)
2014/06/09 Javascript
jQuery中Ajax的load方法详解
2015/01/14 Javascript
举例详解Python中smtplib模块处理电子邮件的使用
2015/06/24 Javascript
原生js实现秒表计时器功能
2017/02/16 Javascript
原生js 封装get ,post, delete 请求的实例
2017/08/11 Javascript
基于zTree树形菜单的使用实例
2017/12/25 Javascript
基于jQuery实现无缝轮播与左右点击效果
2018/05/13 jQuery
详解angular2.x创建项目入门指令
2018/10/11 Javascript
微信小程序开发的基本流程步骤
2019/01/31 Javascript
vue动态注册组件实例代码详解
2019/05/30 Javascript
layui将table转化表单显示的方法(即table.render转为表单展示)
2019/09/24 Javascript
angula中使用iframe点击后不执行变更检测的问题
2020/05/10 Javascript
JavaScript通如何过RGraph实现动态仪表盘
2020/10/15 Javascript
跟老齐学Python之从格式化表达式到方法
2014/09/28 Python
介绍Python中几个常用的类方法
2015/04/08 Python
Python统计文件中去重后uuid个数的方法
2015/07/30 Python
使用Python 统计高频字数的方法
2019/01/31 Python
使用Python自动化破解自定义字体混淆信息的方法实例
2019/02/13 Python
python scipy卷积运算的实现方法
2019/09/16 Python
python 中Arduino串口传输数据到电脑并保存至excel表格
2019/10/14 Python
使用TensorFlow搭建一个全连接神经网络教程
2020/02/06 Python
加拿大约会网站:EliteSingles.ca
2018/01/12 全球购物
船餐厅和泰晤士河餐饮游轮:Bateaux London
2018/03/19 全球购物
查找廉价航班和发现新目的地:Kiwi.com
2019/02/25 全球购物
Yummie官方网站:塑身衣和衣柜必需品
2019/10/29 全球购物
JD Sports丹麦:英国领先的运动时尚零售商
2020/11/24 全球购物
应届大专毕业生自我鉴定
2014/04/08 职场文书
保护环境的标语
2014/06/09 职场文书
导游词之山西关帝庙
2019/11/01 职场文书
pytorch 两个GPU同时训练的解决方案
2021/06/01 Python