浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现网站文件的全备份和差异备份
Nov 30 Python
django开发之settings.py中变量的全局引用详解
Mar 29 Python
Python中执行存储过程及获取存储过程返回值的方法
Oct 07 Python
Django实现简单分页功能的方法详解
Dec 05 Python
使用python实现knn算法
Dec 20 Python
Python cookbook(数据结构与算法)保存最后N个元素的方法
Feb 13 Python
Python爬虫实现全国失信被执行人名单查询功能示例
May 03 Python
python 爬虫 批量获取代理ip的实例代码
May 22 Python
详解python3中zipfile模块用法
Jun 18 Python
Python文本文件的合并操作方法代码实例
Mar 31 Python
Python内存映射文件读写方式
Apr 24 Python
Pycharm 如何设置HTML文件自动补全代码或标签
May 21 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
CodeIgniter框架数据库事务处理的设计缺陷和解决方案
2014/07/25 PHP
php中实现字符串翻转的方法
2017/02/22 PHP
PHP中使用CURL发送get/post请求上传图片批处理功能
2018/10/15 PHP
php中访问修饰符的知识点总结
2019/01/27 PHP
JS的Document属性和方法小结
2013/09/17 Javascript
动态创建script标签实现跨域资源访问的方法介绍
2014/02/28 Javascript
JS+CSS实现弹出全屏灰黑色透明遮罩效果的方法
2014/12/20 Javascript
javascript上下方向键控制表格行选中并高亮显示的方法
2015/02/13 Javascript
基于javascript实现动态时钟效果
2020/08/18 Javascript
node.js缺少mysql模块运行报错的解决方法
2016/11/13 Javascript
基于javascript实现按圆形排列DIV元素(一)
2016/12/02 Javascript
node.js程序作为服务并在windows下开机自启动(用forever)
2017/03/29 Javascript
详解Vue路由开启keep-alive时的注意点
2017/06/20 Javascript
Node.js中Bootstrap-table的两种分页的实现方法
2017/09/18 Javascript
Angular实现的日程表功能【可添加及隐藏显示内容】
2017/12/27 Javascript
js 将canvas生成图片保存,或直接保存一张图片的实现方法
2018/01/02 Javascript
vue中实现methods一个方法调用另外一个方法
2018/02/08 Javascript
vue中$refs的用法及作用详解
2018/04/24 Javascript
vue里input根据value改变背景色的实例
2018/09/29 Javascript
Element输入框带历史查询记录的实现示例
2019/01/15 Javascript
深入浅析vue中cross-env的使用
2019/09/12 Javascript
python 中split 和 strip的实例详解
2017/07/12 Python
python实现关键词提取的示例讲解
2018/04/28 Python
对python中for、if、while的区别与比较方法
2018/06/25 Python
浅谈python图片处理Image和skimage的区别
2019/08/04 Python
Python爬虫实现使用beautifulSoup4爬取名言网功能案例
2019/09/15 Python
Python flask路由间传递变量实例详解
2020/06/03 Python
python进行二次方程式计算的实例讲解
2020/12/06 Python
东南亚旅游平台:The Trip Guru
2018/01/01 全球购物
美国排名第一的葡萄酒俱乐部:Firstleaf Wine Club
2020/01/02 全球购物
工程质量月活动方案
2014/02/19 职场文书
身边的榜样活动方案
2014/08/20 职场文书
夫妻婚内购房协议书
2014/10/05 职场文书
Java循环队列与非循环队列的区别总结
2021/06/22 Java/Android
一篇文章弄清楚Ajax请求的五个步骤
2022/03/17 Javascript
「月刊Action」2022年5月号封面公开
2022/03/21 日漫