浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.3教程之模拟百度登陆代码分享
Jan 16 Python
python爬虫实战之爬取京东商城实例教程
Apr 24 Python
TensorFlow安装及jupyter notebook配置方法
Sep 08 Python
Selenium 模拟浏览器动态加载页面的实现方法
May 16 Python
Python实现的多叉树寻找最短路径算法示例
Jul 30 Python
python按修改时间顺序排列文件的实例代码
Jul 25 Python
python修改字典键(key)的方法
Aug 05 Python
python-序列解包(对可迭代元素的快速取值方法)
Aug 24 Python
python中dict()的高级用法实现
Nov 13 Python
python selenium 执行完毕关闭chromedriver进程示例
Nov 15 Python
Python爬虫工具requests-html使用解析
Apr 29 Python
Django开发RESTful API实现增删改查(入门级)
May 10 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
PHP数据库开发知多少
2006/10/09 PHP
PHP文件去掉PHP注释空格的函数分析(PHP代码压缩)
2013/07/02 PHP
ThinkPHP 3.2 数据分页代码分享
2014/10/14 PHP
yii2.0实现创建简单widgets示例
2016/07/18 PHP
php微信开发之谷歌测距
2018/06/14 PHP
PHP INT类型在内存中占字节详解
2019/07/20 PHP
jQuery 剧场版 你必须知道的javascript
2009/05/27 Javascript
prototype 中文参数乱码解决方案
2009/11/09 Javascript
jquery中get和post的简单实例
2014/02/04 Javascript
我的Node.js学习之路(二)NPM模块管理
2014/07/06 Javascript
使用JavaScript实现旋转的彩圈特效
2015/06/23 Javascript
js实现网页抽奖实例
2015/08/05 Javascript
javascript性能优化之DOM交互操作实例分析
2015/12/12 Javascript
js数组去重的hash方法
2016/12/22 Javascript
JS实现浏览上传文件的代码
2017/08/23 Javascript
JavaScript实现的超简单计算器功能示例
2017/12/23 Javascript
vue2 前端搜索实现示例
2018/02/26 Javascript
angular6 填坑之sdk的方法
2018/12/27 Javascript
Vue使用.sync 实现父子组件的双向绑定数据问题
2019/04/04 Javascript
微信小程序聊天功能的示例代码
2020/01/13 Javascript
JS数组方法reduce的用法实例分析
2020/03/03 Javascript
vue 弹出遮罩层样式实例
2020/07/22 Javascript
原生JS实现拖拽效果
2020/12/04 Javascript
[00:57]林俊杰助阵DOTA2亚洲邀请赛
2015/01/28 DOTA
详解python3中tkinter知识点
2018/06/21 Python
Python os.access()用法实例
2019/02/18 Python
PyQt5 在label显示的图片中绘制矩形的方法
2019/06/17 Python
pymysql模块的使用(增删改查)详解
2019/09/09 Python
pytorch多GPU并行运算的实现
2019/09/27 Python
html5适合移动应用开发的12大特性
2014/03/19 HTML / CSS
html5 canvas的绘制文本自动换行的示例代码
2018/09/17 HTML / CSS
美国体育用品商店:Paragon Sports
2017/10/08 全球购物
波兰快递服务:Globkurier.pl
2019/11/08 全球购物
家具促销活动方案
2014/02/16 职场文书
医学生职业生涯规划书范文
2014/03/13 职场文书
营销计划书
2015/01/17 职场文书