Pytorch 中net.train 和 net.eval的使用说明


Posted in Python onMay 22, 2021

在训练模型时会在前面加上:

model.train()

在测试模型时在前面使用:

model.eval()

同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。

训练时是正对每个min-batch的,但是在测试中往往是针对单张图片,即不存在min-batch的概念。

由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有batch的均值和方差。

所有Batch Normalization的训练和测试时的操作不同

在训练中,每个隐层的神经元先乘概率P,然后在进行激活,在测试中,所有的神经元先进行激活,然后每个隐层神经元的输出乘P。

补充:Pytorch踩坑记录——model.eval()

最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行inference准确率直接掉了5个点,尼玛,这简直不能忍啊~本菜鸡下意识地感知到我肯定又在哪里写了bug了~~~于是开始到处排查,从model load到data load,最终在一个被我封装好的module的犄角旮旯里找到了问题,于是顺便就在这里总结一下,避免以后再犯。

对于训练好的模型加载进来准确率和原先的不符,比较常见的有两方面的原因:

1)data

2)model.state_dict()

1) data

数据方面,检查前后两次加载的data有没有发生变化。首先检查 transforms.Normalize 使用的均值和方差是否和训练时相同;另外检查在这个过程中数据是否经过了存储形式的改变,这有可能会带来数据精度的变化导致一定的信息丢失。

比如我过用的其中一个数据集,原先将图片存储成向量形式,但其对应的是“png”格式的数据(后来在原始文件中发现了相应的描述。),而我进行了一次data-to-img操作,将向量转换成了“jpg”形式,这时加载进来便造成了掉点。

2)model.state_dict()

第一方面造成的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那就误差大了。

如果是参数没有正确加载进来则比较容易发现,这时准确率非常低,几乎等于瞎猜。

而我这次遇到的情况是,准确率并不是特别低,只掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行inference时,忘了写 ‘model.eval()',造成了模型的参数发生变化,再次调用则出现了掉点。

于是又回顾了一下model.eval()和model.train()的具体作用。如下:

model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch

Normalization 和 Dropout 方法模式不同:

a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;

b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练q起到防止网络过拟合的问题。

因此,在使用PyTorch进行训练和测试时一定要记得把实例化的model指定train/eval。

model.eval()   vs   torch.no_grad()

虽然二者都是eval的时候使用,但其作用并不相同:

model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。 见下方代码:

import torch
  import torch.nn as nn
 
  drop = nn.Dropout()
  x = torch.ones(10)
  
  # Train mode   
  drop.train()
  print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.])   
  
  # Eval mode   
  drop.eval()
  print(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

torch.no_grad() 负责关掉梯度计算,节省eval的时间。

只进行inference时,model.eval()是必须使用的,否则会影响结果准确性。 而torch.no_grad()并不是强制的,只影响运行效率。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用try except处理程序异常的三种常用方法分析
Sep 05 Python
在pycharm中python切换解释器失败的解决方法
Oct 29 Python
pandas.DataFrame删除/选取含有特定数值的行或列实例
Nov 07 Python
解决Python print输出不换行没空格的问题
Nov 14 Python
对python多线程SSH登录并发脚本详解
Feb 14 Python
Python数据分析模块pandas用法详解
Sep 04 Python
基于python的docx模块处理word和WPS的docx格式文件方式
Feb 13 Python
Python使用GitPython操作Git版本库的方法
Feb 29 Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 Python
经验丰富程序员才知道的8种高级Python技巧
Jul 27 Python
Python通过类的组合模拟街道红绿灯
Sep 16 Python
利用django创建一个简易的博客网站的示例
Sep 29 Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
python 定义函数 返回值只取其中一个的实现
May 21 #Python
Python+Appium实现自动抢微信红包
You might like
PHP下使用CURL方式POST数据至API接口的代码
2013/02/14 PHP
php生成mysql的数据字典
2016/07/07 PHP
laravel 5.1下php artisan migrate的使用注意事项总结
2017/06/07 PHP
[原创]PHP实现字节数Byte转换为KB、MB、GB、TB的方法
2017/08/31 PHP
PHP闭包定义与使用简单示例
2018/04/13 PHP
yii 框架实现按天,月,年,自定义时间段统计数据的方法分析
2020/04/04 PHP
再谈ie和firefox下的document.all属性
2009/10/21 Javascript
解决jquery .ajax 在IE下卡死问题的解决方法
2009/10/26 Javascript
JS将光标聚焦在文本最后的实现代码
2014/03/28 Javascript
javascript学习笔记之10个原生技巧
2014/05/21 Javascript
浅谈Javascript的静态属性和原型属性
2015/05/07 Javascript
快速掌握Node.js之Window下配置NodeJs环境
2016/03/21 NodeJs
使用nodejs中httpProxy代理时候出现404异常的解决方法
2016/08/15 NodeJs
微信小程序 在Chrome浏览器上运行以及WebStorm的使用
2016/09/27 Javascript
学习JavaScript图片预加载模块
2016/11/07 Javascript
解决Node.js使用MySQL出现connect ECONNREFUSED 127.0.0.1:3306的问题
2017/03/09 Javascript
JavaScript时间戳与时间日期间相互转换
2017/12/11 Javascript
JS根据Unix时间戳显示发布时间是多久前【项目实测】
2019/07/10 Javascript
[01:09:10]NB vs Liquid Supermajor小组赛 A组胜者组决赛 BO3 第一场 6.2
2018/06/04 DOTA
不要用强制方法杀掉python线程
2017/02/26 Python
对Python中实现两个数的值交换的集中方法详解
2019/01/11 Python
Django如何自定义model创建数据库索引的顺序
2019/06/20 Python
详解Python并发编程之从性能角度来初探并发编程
2019/08/23 Python
浅析Python requests 模块
2020/10/09 Python
linux系统下pip升级报错的解决方法
2021/01/31 Python
css3 flex布局 justify-content:space-between 最后一行左对齐
2020/01/02 HTML / CSS
Lowe’s加拿大:家居装修、翻新和五金店
2019/12/06 全球购物
最畅销的视频游戏享受高达90%的折扣:CDKeys
2020/02/10 全球购物
Linux开机引导的步骤是什么
2014/02/26 面试题
关于教师节的演讲稿
2014/09/04 职场文书
医药公司采购员岗位职责
2014/09/12 职场文书
普通党员群众路线教育实践活动心得体会
2014/11/04 职场文书
2014年流动人口工作总结
2014/11/26 职场文书
2015年个人自我剖析材料
2014/12/29 职场文书
税务会计岗位职责
2015/04/02 职场文书
安全教育第一课观后感
2015/06/17 职场文书