浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
学习python处理python编码问题
Mar 13 Python
简单介绍Ruby中的CGI编程
Apr 10 Python
Python映射拆分操作符用法实例
May 19 Python
Python学习小技巧总结
Jun 10 Python
Python使用matplotlib绘制随机漫步图
Aug 27 Python
Matplotlib中文乱码的3种解决方案
Nov 15 Python
详解python实现数据归一化处理的方式:(0,1)标准化
Jul 17 Python
pyftplib中文乱码问题解决方案
Jan 11 Python
Python列表list操作相关知识小结
Jan 29 Python
python numpy实现多次循环读取文件 等间隔过滤数据示例
Mar 14 Python
python在CMD界面读取excel所有数据的示例
Sep 28 Python
python图像处理 PIL Image操作实例
Apr 09 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
php部分常见问题总结
2008/03/27 PHP
PHP 5.5 创建和验证哈希最简单的方法详解
2013/11/07 PHP
php中Session的生成机制、回收机制和存储机制探究
2014/08/19 PHP
Javascript YUI 读码日记之 YAHOO.util.Dom - Part.2 0
2008/03/22 Javascript
JavaScript 实现简单的倒计时弹窗DEMO附图
2014/03/05 Javascript
javascript框架设计读书笔记之模块加载系统
2014/12/02 Javascript
js使用DOM设置单选按钮、复选框及下拉菜单的方法
2015/01/20 Javascript
基于JS实现导航条之调用网页助手小精灵的方法
2016/06/17 Javascript
RequireJS多页面应用实例分析
2016/06/29 Javascript
vue2.0父子组件及非父子组件之间的通信方法
2017/01/21 Javascript
ES6新特性之解构、参数、模块和记号用法示例
2017/04/01 Javascript
MUI 解决动态列表页图片懒加载再次加载不成功的bug问题
2017/04/13 Javascript
Express之get,pos请求参数的获取
2017/05/02 Javascript
Vue2.0基于vue-cli+webpack父子组件通信(实例讲解)
2017/09/14 Javascript
vue 1.x 交互实现仿百度下拉列表示例
2017/10/21 Javascript
纯JavaScript实现实时反馈系统时间
2017/10/26 Javascript
分享ES6的7个实用技巧
2018/01/18 Javascript
JavaScript日期工具类DateUtils定义与用法示例
2018/09/03 Javascript
微信小程序 网络通信实现详解
2019/07/23 Javascript
封装 axios+promise通用请求函数操作
2020/08/11 Javascript
Python的Django框架中URLconf相关的一些技巧整理
2015/07/18 Python
TensorFlow使用Graph的基本操作的实现
2020/04/22 Python
Python爬虫模拟登陆哔哩哔哩(bilibili)并突破点选验证码功能
2020/12/21 Python
Python如何实现Paramiko的二次封装
2021/01/30 Python
纯css3实现宠物小鸡实例代码
2018/10/08 HTML / CSS
TripAdvisor台湾:全球最大旅游网站
2018/08/26 全球购物
SportsDirect.com马来西亚:英国第一体育零售商
2018/11/21 全球购物
C++的几个面试题附答案
2016/08/03 面试题
运动会入场词50字
2014/02/20 职场文书
可口可乐广告词
2014/03/20 职场文书
妈妈活动方案
2014/08/15 职场文书
Anaconda安装pytorch及配置PyCharm 2021环境
2021/06/04 Python
SQL实现LeetCode(178.分数排行)
2021/08/04 MySQL
Nginx图片服务器配置之后图片访问404的问题解决
2022/03/21 Servers
《黑岩★★射手 DAWN FALL》BD发售宣传CM公开
2022/04/04 日漫
不负正版帝国之名 《重返帝国》引领SLG手游制作新的标杆
2022/04/07 其他游戏