踩坑:pytorch中eval模式下结果远差于train模式介绍


Posted in Python onJune 23, 2020

首先,eval模式和train模式得到不同的结果是正常的。我的模型中,eval模式和train模式不同之处在于Batch Normalization和Dropout。Dropout比较简单,在train时会丢弃一部分连接,在eval时则不会。Batch Normalization,在train时不仅使用了当前batch的均值和方差,也使用了历史batch统计上的均值和方差,并做一个加权平均(momentum参数)。在test时,由于此时batchsize不一定一致,因此不再使用当前batch的均值和方差,仅使用历史训练时的统计值。

我出bug的现象是,train模式下可以收敛,但一旦在测试中切换到了eval模式,结果就很差。如果在测试中仍沿用train模式,反而可以得到不错的结果。为了确保是程序bug而不是算法本身就不适合于预测,我在测试时再次使用了训练集,正常情况下此时应发生过拟合,正确率一定会很高,然而eval模式下正确率仍然很低。参照网上的一些说法(Performance highly degraded when eval() is activated in the test phase
),我调大了batchsize,降低了BN层的momentum,检查了是否存在不同层使用相同BN层的bug,均不见效。有一种方法说应在BN层设置track_running_stats为False,它虽然带来了好的效果,但实际上它只不过是不用eval模式,切回train模式罢了,所以也不对。

学习了在训练过程中,如何将BN层中统计的均值和方差输出。即在forward()中,

# bn是一个BN层,torch.nn.batch_normalization(...)
print(bn.running_mean)
print(bn.running_var)

同时学习了如何输出一个Tensor自身的均值和方差,即

# x是一个Tensor,dims是需要计算的维度
print(x.cpu().detach().numpy().mean(dims)
print(x.cpu().detach().numpy().var(dims)

观察每一层的输出结果,发现出现了很大的方差,才猛然意识到自己的输入数据没有做归一化(事后想想也确实如此,毕竟模型和训练方法都是github上参考别人的,出错概率很小;反而是自己写的DataSet部分,其实是最容易出错的)。给模型加上归一化后,eval和train的结果就没有问题了。

再次验证了我的观点:越是玄学的问题,越是傻逼的bug。

补充知识:Pytorch中的train和eval用法注意点

1.介绍

一般情况,model.train()是在训练的时候用到,model.eval()是在测试的时候用到

2.用法

如果模型中没有类似于BN这样的归一化或者Dropout,model.train()和model.eval()可以不要(建议写一下,比较安全),并且model.train()和model.eval()得到的效果是一样

如果模型中有类似于BN这样的归一化或者Dropout,并且程序需要边训练和边测试,最好就是用model.eval()测试完之后,后面补一个model.train()。

其中model.train()是保证BN用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接(结果是取了平均)

以上这篇踩坑:pytorch中eval模式下结果远差于train模式介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python采集博客中上传的QQ截图文件
Jul 18 Python
Python实现树的先序、中序、后序排序算法示例
Jun 23 Python
python线程池threadpool实现篇
Apr 27 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 Python
在python中以相同顺序shuffle两个list的方法
Dec 13 Python
VSCode Python开发环境配置的详细步骤
Feb 22 Python
详解python中的生成器、迭代器、闭包、装饰器
Aug 22 Python
python输出带颜色字体实例方法
Sep 01 Python
Python 类的私有属性和私有方法实例分析
Sep 29 Python
基于Python实现船舶的MMSI的获取(推荐)
Oct 21 Python
python os模块常用的29种方法使用详解
Jun 02 Python
解决TensorFlow训练模型及保存数量限制的问题
Mar 03 Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
python使用QQ邮箱实现自动发送邮件
Jun 22 #Python
浅谈keras中loss与val_loss的关系
Jun 22 #Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
You might like
php获取地址栏信息的代码
2008/10/08 PHP
Thinkphp中import的几个用法详细介绍
2014/07/02 PHP
基于PHP的简单采集数据入库程序
2014/07/30 PHP
如何使用php脚本给html中引用的js和css路径打上版本号
2015/11/18 PHP
写了一个layout,拖动条连贯,内容区可为iframe
2007/08/19 Javascript
JS使用正则表达式实现关键字替换加粗功能示例
2016/08/03 Javascript
jQuery Validation Engine验证控件调用外部函数验证的方法
2017/01/18 Javascript
用nodejs搭建websocket服务器
2017/01/23 NodeJs
单击按钮发送验证码,出现倒计时的简单实例
2017/03/17 Javascript
浅谈js基础数据类型和引用类型,深浅拷贝问题,以及内存分配问题
2017/09/02 Javascript
分享vue里swiper的一些坑
2018/08/30 Javascript
微信小程序图片加载失败时替换为默认图片的方法
2019/12/09 Javascript
echarts饼图各个板块之间的空隙如何实现
2020/12/01 Javascript
微信小程序实现购物车小功能
2020/12/30 Javascript
Django集成百度富文本编辑器uEditor攻略
2014/07/04 Python
python开发之tkinter实现图形随鼠标移动的方法
2015/11/11 Python
python监控文件或目录变化
2016/06/07 Python
详解爬虫被封的问题
2019/04/23 Python
pandas将多个dataframe以多个sheet的形式保存到一个excel文件中
2019/10/10 Python
Python实现元素等待代码实例
2019/11/11 Python
python3 webp转gif格式的实现示例
2019/12/10 Python
From CSV to SQLite3 by python 导入csv到sqlite实例
2020/02/14 Python
关于Python turtle库使用时坐标的确定方法
2020/03/19 Python
pytorch中的weight-initilzation用法
2020/06/24 Python
python正则表达式 匹配反斜杠的操作方法
2020/08/07 Python
Pycharm 设置默认解释器路径和编码格式的操作
2021/02/05 Python
美国著名手表网站:Timepiece
2017/11/15 全球购物
Beach Bunny Swimwear官网:设计师泳装和性感比基尼
2019/03/13 全球购物
Java中各种基本数据类型的默认值都是什么
2016/12/22 面试题
家居饰品店创业计划书
2014/01/31 职场文书
2014年端午节活动方案
2014/03/11 职场文书
授权委托书
2014/09/17 职场文书
校园开放日新闻稿
2015/07/17 职场文书
四十年同学聚会致辞
2015/07/28 职场文书
记者节感言
2015/08/03 职场文书
zabbix 代理服务器的部署与 zabbix-snmp 监控问题
2022/07/15 Servers