聊聊pytorch测试的时候为何要加上model.eval()


Posted in Python onMay 23, 2021

Do need to use model.eval() when I test?

Sure, Dropout works as a regularization for preventing overfitting during training.

It randomly zeros the elements of inputs in Dropout layer on forward call.

It should be disabled during testing since you may want to use full model (no element is masked)

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

补充:pytorch中model eval和torch no grad()的区别

model.eval()和with torch.no_grad()的区别

在PyTorch中进行validation时,会使用model.eval()切换到测试模式,在该模式下,

主要用于通知dropout层和batchnorm层在train和val模式间切换

在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。

在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。

该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反传(backprobagation)

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用,具体行为就是停止gradient计算,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。

使用场景

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation的结果;而with torch.zero_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储gradient),从而可以更快计算,也可以跑更大的batch来测试。

补充:Pytorch的modle.train,model.eval,with torch.no_grad的个人理解

1. 最近在学习pytorch过程中遇到了几个问题

不理解为什么在训练和测试函数中model.eval(),和model.train()的区别,经查阅后做如下整理

一般情况下,我们训练过程如下:

1、拿到数据后进行训练,在训练过程中,使用

model.train():告诉我们的网络,这个阶段是用来训练的,可以更新参数。

2、训练完成后进行预测,在预测过程中,使用

model.eval() : 告诉我们的网络,这个阶段是用来测试的,于是模型的参数在该阶段不进行更新。

2. 但是为什么在eval()阶段会使用with torch.no_grad()?

查阅相关资料:传送门

with torch.no_grad - disables tracking of gradients in autograd.

model.eval() changes the forward() behaviour of the module it is called upon

eg, it disables dropout and has batch norm use the entire population statistics

总结一下就是说,在eval阶段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都会进行预测,而使用no_grad则设置让梯度Autograd设置为False(因为在训练中我们默认是True),这样保证了反向过程为纯粹的测试,而不变参数。

另外,参考文档说这样避免每一个参数都要设置,解放了GPU底层的时间开销,在测试阶段统一梯度设置为False

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python算法之求n个节点不同二叉树个数
Oct 27 Python
Python实现绘制双柱状图并显示数值功能示例
Jun 23 Python
django 发送邮件和缓存的实现代码
Jul 18 Python
详解flask入门模板引擎
Jul 18 Python
对TensorFlow的assign赋值用法详解
Jul 30 Python
matplotlib给子图添加图例的方法
Aug 03 Python
pybind11在Windows下的使用教程
Jul 04 Python
常用python爬虫库介绍与简要说明
Jan 25 Python
在django中使用apscheduler 执行计划任务的实现方法
Feb 11 Python
Jupyter notebook 远程配置及SSL加密教程
Apr 14 Python
Pycharm 设置默认解释器路径和编码格式的操作
Feb 05 Python
Python快速优雅的批量修改Word文档样式
May 20 Python
PyTorch 如何自动计算梯度
May 23 #Python
解决numpy和torch数据类型转化的问题
May 23 #Python
Python 用户输入和while循环的操作
May 23 #Python
解决Tkinter中button按钮未按却主动执行command函数的问题
May 23 #Python
python tkinter Entry控件的焦点移动操作
May 22 #Python
python3.7.2 tkinter entry框限定输入数字的操作
May 22 #Python
tensorboard 可视化之localhost:6006不显示的解决方案
You might like
php学习之function的用法
2012/07/14 PHP
PHP动态生成javascript文件的2个例子
2014/04/11 PHP
php防止伪造的数据从URL提交方法
2014/06/27 PHP
PHP基于接口技术实现简单的多态应用完整实例
2017/04/26 PHP
php+ajax实现仿百度查询下拉内容功能示例
2017/10/20 PHP
Yii2 中实现单点登录的方法
2018/03/09 PHP
表单填写时用回车代替TAB的实现方法
2007/10/09 Javascript
在IE,Firefox,Safari,Chrome,Opera浏览器上调试javascript
2008/12/02 Javascript
Javascript 类型转换方法
2010/10/24 Javascript
关于JAVASCRIPT urldecode URL解码的问题
2012/01/08 Javascript
基于jQuery的input输入框下拉提示层(自动邮箱后缀名)
2012/06/14 Javascript
JS创建自定义表格具体实现
2014/02/11 Javascript
跟我学Nodejs(一)--- Node.js简介及安装开发环境
2014/05/20 NodeJs
JQuery中extend的用法实例分析
2015/02/08 Javascript
JavaScript实现自动对页面上敏感词进行屏蔽的方法
2015/07/27 Javascript
详解vue2.0组件通信各种情况总结与实例分析
2017/03/22 Javascript
vue2导航根据路由传值,而改变导航内容的实例
2017/11/10 Javascript
ES6学习笔记之map、set与数组、对象的对比
2018/03/01 Javascript
简述JS浏览器的三种弹窗
2018/07/15 Javascript
layui实现table加载的示例代码
2018/08/14 Javascript
js实现京东秒杀倒计时功能
2019/01/21 Javascript
JavaScript中的连续赋值问题实例分析
2019/07/12 Javascript
微信小程序实现单个卡片左滑显示按钮并防止上下滑动干扰功能
2019/12/06 Javascript
javascript 易错知识点实例小结
2020/04/25 Javascript
[01:20]辉夜杯背景故事宣传片《辉夜传说》
2015/12/25 DOTA
python连接远程ftp服务器并列出目录下文件的方法
2015/04/01 Python
python使用代理ip访问网站的实例
2018/05/07 Python
python自动化报告的输出用例详解
2018/05/30 Python
Python爬虫实现抓取京东店铺信息及下载图片功能示例
2018/08/07 Python
用Python实现读写锁的示例代码
2018/11/05 Python
Python3 pip3 list 出现 DEPRECATION 警告的解决方法
2019/02/16 Python
python+OpenCV实现车牌号码识别
2019/11/08 Python
pytorch获取模型某一层参数名及参数值方式
2019/12/30 Python
社区国庆节活动总结
2015/03/23 职场文书
详解前端任务构建利器Gulp.js使用指南
2021/04/30 Javascript
收音机爱好者玩机13年,简评其使用过的19台收音机
2022/04/30 无线电