浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python判断文件和文件夹是否存在的方法
May 21 Python
python使用生成器实现可迭代对象
Mar 20 Python
python生成ppt的方法
Jun 07 Python
python之信息加密题目详解
Jun 26 Python
Python使用ffmpy将amr格式的音频转化为mp3格式的例子
Aug 08 Python
调用其他python脚本文件里面的类和方法过程解析
Nov 15 Python
Pytorch技巧:DataLoader的collate_fn参数使用详解
Jan 08 Python
Python守护进程实现过程详解
Feb 10 Python
Python *args和**kwargs用法实例解析
Mar 02 Python
Python 文本滚动播放器的实现代码
Apr 25 Python
有趣的二维码:使用MyQR和qrcode来制作二维码
May 10 Python
Python利用capstone实现反汇编
Apr 06 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
PHP中的替代语法简介
2014/08/22 PHP
php 使用ActiveMQ发送消息,与处理消息操作示例
2020/02/23 PHP
js验证模型自我实现的具体方法
2013/06/21 Javascript
jquery 获取表单元素里面的值示例代码
2013/07/28 Javascript
Dojo Javascript 编程规范 规范自己的JavaScript书写
2014/10/26 Javascript
JQUERY简单按钮轮换选中效果实现方法
2015/05/07 Javascript
基于JS+Canves实现点击按钮水波纹效果
2016/09/15 Javascript
JS异步加载的三种实现方式
2017/03/16 Javascript
angular实现页面打印局部功能的思考与方法
2018/04/13 Javascript
jQuery实现表单动态加减、ajax表单提交功能
2018/06/08 jQuery
解析vue路由异步组件和懒加载案例
2018/06/08 Javascript
vue构建动态表单的方法示例
2018/09/22 Javascript
vue拖拽组件使用方法详解
2018/12/01 Javascript
VUE解决微信签名及SPA微信invalid signature问题(完美处理)
2019/03/29 Javascript
NodeJs 模仿SIP话机注册的方法
2019/06/21 NodeJs
利用不到200行代码写一款属于你自己的js类库
2019/07/08 Javascript
微信小程序 Storage更新详解
2019/07/16 Javascript
jQuery zTree插件快速实现目录树
2019/08/16 jQuery
解决layui使用layui-icon出现默认图标的问题
2019/09/11 Javascript
Node.js web 应用如何封装到Docker容器中
2020/09/01 Javascript
Ant-design-vue Table组件customRow属性的使用说明
2020/10/28 Javascript
布同 Python中文问题解决方法(总结了多位前人经验,初学者必看)
2011/03/13 Python
对python中字典keys,values,items的使用详解
2019/02/03 Python
pytorch 多分类问题,计算百分比操作
2020/07/09 Python
Python下载网易云歌单歌曲的示例代码
2020/08/12 Python
事业单位请假制度
2014/01/13 职场文书
小学中秋节活动方案
2014/02/06 职场文书
应届毕业生应聘自荐信范文
2014/02/26 职场文书
法制教育演讲稿
2014/09/10 职场文书
党的群众路线教育实践活动个人对照检查材料范文
2014/09/25 职场文书
个人务虚会发言材料
2014/10/20 职场文书
学校团代会开幕词
2016/03/04 职场文书
Nginx代理同域名前后端分离项目的完整步骤
2021/03/31 Servers
Django开发RESTful API实现增删改查(入门级)
2021/05/10 Python
WINDOWS下安装mysql 8.x 的方法图文教程
2022/04/19 MySQL
js前端图片加载异常兜底方案
2022/06/21 Javascript