Pytorch 中net.train 和 net.eval的使用说明


Posted in Python onMay 22, 2021

在训练模型时会在前面加上:

model.train()

在测试模型时在前面使用:

model.eval()

同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。

训练时是正对每个min-batch的,但是在测试中往往是针对单张图片,即不存在min-batch的概念。

由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有batch的均值和方差。

所有Batch Normalization的训练和测试时的操作不同

在训练中,每个隐层的神经元先乘概率P,然后在进行激活,在测试中,所有的神经元先进行激活,然后每个隐层神经元的输出乘P。

补充:Pytorch踩坑记录——model.eval()

最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行inference准确率直接掉了5个点,尼玛,这简直不能忍啊~本菜鸡下意识地感知到我肯定又在哪里写了bug了~~~于是开始到处排查,从model load到data load,最终在一个被我封装好的module的犄角旮旯里找到了问题,于是顺便就在这里总结一下,避免以后再犯。

对于训练好的模型加载进来准确率和原先的不符,比较常见的有两方面的原因:

1)data

2)model.state_dict()

1) data

数据方面,检查前后两次加载的data有没有发生变化。首先检查 transforms.Normalize 使用的均值和方差是否和训练时相同;另外检查在这个过程中数据是否经过了存储形式的改变,这有可能会带来数据精度的变化导致一定的信息丢失。

比如我过用的其中一个数据集,原先将图片存储成向量形式,但其对应的是“png”格式的数据(后来在原始文件中发现了相应的描述。),而我进行了一次data-to-img操作,将向量转换成了“jpg”形式,这时加载进来便造成了掉点。

2)model.state_dict()

第一方面造成的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那就误差大了。

如果是参数没有正确加载进来则比较容易发现,这时准确率非常低,几乎等于瞎猜。

而我这次遇到的情况是,准确率并不是特别低,只掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行inference时,忘了写 ‘model.eval()',造成了模型的参数发生变化,再次调用则出现了掉点。

于是又回顾了一下model.eval()和model.train()的具体作用。如下:

model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch

Normalization 和 Dropout 方法模式不同:

a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;

b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练q起到防止网络过拟合的问题。

因此,在使用PyTorch进行训练和测试时一定要记得把实例化的model指定train/eval。

model.eval()   vs   torch.no_grad()

虽然二者都是eval的时候使用,但其作用并不相同:

model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。 见下方代码:

import torch
  import torch.nn as nn
 
  drop = nn.Dropout()
  x = torch.ones(10)
  
  # Train mode   
  drop.train()
  print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.])   
  
  # Eval mode   
  drop.eval()
  print(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

torch.no_grad() 负责关掉梯度计算,节省eval的时间。

只进行inference时,model.eval()是必须使用的,否则会影响结果准确性。 而torch.no_grad()并不是强制的,只影响运行效率。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxPython中文教程入门实例
Jun 09 Python
简单使用Python自动生成文章
Dec 25 Python
快速实现基于Python的微信聊天机器人示例代码
Mar 03 Python
详解python之多进程和进程池(Processing库)
Jun 09 Python
Python导入模块时遇到的错误分析
Aug 30 Python
python pandas 对series和dataframe的重置索引reindex方法
Jun 07 Python
PyQt+socket实现远程操作服务器的方法示例
Aug 22 Python
django实现将修改好的新模型写入数据库
Mar 31 Python
vscode+PyQt5安装详解步骤
Aug 12 Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
Nov 18 Python
用Python制作灯光秀短视频的思路详解
Apr 13 Python
Python实现单例模式的5种方法
Jun 15 Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
python 定义函数 返回值只取其中一个的实现
May 21 #Python
Python+Appium实现自动抢微信红包
You might like
咖啡知识 除了喝咖啡还有那些知识点
2021/03/06 新手入门
PHP提取中文首字母
2008/04/09 PHP
php自定义函数截取汉字长度
2014/05/15 PHP
PHP递归遍历指定目录的文件并统计文件数量的方法
2015/03/24 PHP
PHP中file_get_contents函数抓取https地址出错的解决方法(两种方法)
2015/09/22 PHP
WordPress用户登录框密码的隐藏与部分显示技巧
2015/12/31 PHP
ThinkPHP5.1表单令牌Token失效问题的解决
2019/03/22 PHP
JavaScript中的排序算法代码
2011/02/22 Javascript
jquery的ajax请求全面了解
2013/03/20 Javascript
jquery实现图片灯箱明暗的遮罩效果
2013/11/15 Javascript
JQuery操作iframe父页面与子页面的元素与方法(实例讲解)
2013/11/20 Javascript
第四章之BootStrap表单与图片
2016/04/25 Javascript
微信小程序 开发之滑块视图容器(swiper)详解及实例代码
2017/02/22 Javascript
vue项目常用组件和框架结构介绍
2017/12/24 Javascript
video.js 实现视频只能后退不能快进的思路详解
2018/08/09 Javascript
vue 解决循环引用组件报错的问题
2018/09/06 Javascript
详解如何在微信小程序开发中正确的使用vant ui组件
2018/09/13 Javascript
[55:56]NB vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.22
2019/09/05 DOTA
Python下Fabric的简单部署方法
2015/07/14 Python
python学习教程之使用py2exe打包
2017/09/24 Python
Python3实现的简单验证码识别功能示例
2018/05/02 Python
python GUI库图形界面开发之PyQt5多行文本框控件QTextEdit详细使用方法实例
2020/02/28 Python
利用HTML5 Canvas API绘制矩形的超级攻略
2016/03/21 HTML / CSS
美国知名运动产品零售商:Foot Locker
2016/07/23 全球购物
高山背包:High Sierra
2017/11/23 全球购物
大唐面试试题(CPU,UNIX等等)
2012/01/11 面试题
运动会闭幕式解说词
2014/02/21 职场文书
公司管理建议书范文
2014/03/12 职场文书
大专应届毕业生求职信
2014/07/15 职场文书
作风整顿个人剖析材料
2014/10/06 职场文书
中学感恩教育活动总结
2015/05/05 职场文书
地道战观后感2000字
2015/06/04 职场文书
庆七一主持词
2015/06/29 职场文书
2016暑期师德培训心得体会
2016/01/09 职场文书
多表查询、事务、DCL
2021/04/05 MySQL
PyQt5实现多张图片显示并滚动
2021/06/11 Python