浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程之UDP通信实例(含服务器端、客户端、UDP广播例子)
Apr 25 Python
Python中处理unchecked未捕获异常实例
Jan 17 Python
Python中title()方法的使用简介
May 20 Python
Python对数据库操作
Mar 28 Python
Python获取SQLite查询结果表列名的方法
Jun 21 Python
Python实现针对给定单链表删除指定节点的方法
Apr 12 Python
对python中的logger模块全面讲解
Apr 28 Python
pandas 将list切分后存入DataFrame中的实例
Jul 03 Python
pygame游戏之旅 创建游戏窗口界面
Nov 20 Python
Django unittest 设置跳过某些case的方法
Dec 26 Python
python实现转圈打印矩阵
Mar 02 Python
基于Python的一个自动录入表格的小程序
Aug 05 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
PHP对表单提交特殊字符的过滤和处理方法汇总
2014/02/18 PHP
解决PhpMyAdmin中导入2M以上大文件限制的方法分享
2014/06/06 PHP
php文件缓存类汇总
2014/11/21 PHP
php自定义urlencode,urldecode函数实例
2015/03/24 PHP
php检查页面是否被百度收录
2015/10/28 PHP
PHP新建类问题分析及解决思路
2015/11/19 PHP
PHP给源代码加密的几种方法汇总(推荐)
2018/02/06 PHP
用javascript实现的图片马赛克后显示并切换加文字功能
2007/04/21 Javascript
jquery 设置元素相对于另一个元素的top值(实例代码)
2013/11/06 Javascript
javascript数组快速打乱重排的方法
2014/01/02 Javascript
基于Jquery+Ajax+Json实现分页显示附效果图
2014/07/30 Javascript
jQuery 利用$.ajax 时获取原生XMLHttpRequest 对象的方法
2016/08/25 Javascript
浅谈JS中的!=、== 、!==、===的用法和区别
2016/09/24 Javascript
微信小程序中的onLoad详解及简单实例
2017/04/05 Javascript
JavaScript数据结构之二叉树的遍历算法示例
2017/04/13 Javascript
[02:27]《DAC最前线》之附加赛征程
2015/01/29 DOTA
python fabric实现远程操作和部署示例
2014/03/25 Python
总结Python中逻辑运算符的使用
2015/05/13 Python
python3之微信文章爬虫实例讲解
2017/07/12 Python
Python if语句知识点用法总结
2018/06/10 Python
对python模块中多个类的用法详解
2019/01/10 Python
python交换两个变量的值方法
2019/01/12 Python
python将excel转换为csv的代码方法总结
2019/07/03 Python
Pytorch 保存模型生成图片方式
2020/01/10 Python
解决Jupyter Notebook使用parser.parse_args出现错误问题
2020/04/20 Python
python是怎么被发明的
2020/06/15 Python
Python中openpyxl实现vlookup函数的实例
2020/10/28 Python
英国办公用品商店:Office Outlet
2018/04/04 全球购物
DOUGLAS波兰:在线销售香水和化妆品
2020/07/05 全球购物
英国领先的在线鱼贩:The Fish Society
2020/08/12 全球购物
灵泰克Java笔试题
2016/01/09 面试题
八一建军节部队活动方案
2014/02/04 职场文书
主持词开场白
2014/03/17 职场文书
中等生评语大全
2014/05/04 职场文书
2015年幼儿园班主任个人工作总结
2015/10/22 职场文书
html+css实现文字折叠特效实例
2021/06/02 HTML / CSS