pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Tornado协程在python2.7如何返回值(实现方法)
Jun 22 Python
matplotlib 纵坐标轴显示数据值的实例
May 25 Python
Python实现的读写json文件功能示例
Jun 05 Python
Python3数据库操作包pymysql的操作方法
Jul 16 Python
python 实现批量xls文件转csv文件的方法
Oct 23 Python
python使用wxpy轻松实现微信防撤回的方法
Feb 21 Python
Django框架使用mysql视图操作示例
May 15 Python
Django 多对多字段的更新和插入数据实例
Mar 31 Python
jupyter lab的目录调整及设置默认浏览器为chrome的方法
Apr 10 Python
解决keras,val_categorical_accuracy:,0.0000e+00问题
Jul 02 Python
python使用opencv resize图像不进行插值的操作
Jul 05 Python
python Paramiko使用示例
Sep 21 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
PHP面向对象法则
2012/02/23 PHP
PHP笔记之:日期函数的使用介绍
2013/04/24 PHP
laravel 5 实现模板主题功能(续)
2015/03/02 PHP
php PDO判断连接是否可用的实现方法
2017/04/03 PHP
jQuery实现的一个自定义Placeholder属性插件
2014/08/11 Javascript
JavaScript中的Math.LOG2E属性使用详解
2015/06/14 Javascript
JavaScript代码判断点击第几个按钮
2015/12/13 Javascript
jquery插件jquery.confirm弹出确认消息
2015/12/22 Javascript
MVC+jQuery.Ajax异步实现增删改查和分页
2020/12/22 Javascript
浅谈Webpack自动化构建实践指南
2017/12/18 Javascript
javascript实现自由编辑图片代码详解
2019/06/21 Javascript
koa2服务端使用jwt进行鉴权及路由权限分发的流程分析
2019/07/22 Javascript
详解elementui之el-image-viewer(图片查看器)
2019/08/30 Javascript
JS随机密码生成算法
2019/09/23 Javascript
JavaScript中变量提升和函数提升的详解
2020/08/07 Javascript
js实现验证码干扰(动态)
2021/02/23 Javascript
搭建Python的Django框架环境并建立和运行第一个App的教程
2016/07/02 Python
python 用下标截取字符串的实例
2018/12/25 Python
Python之使用adb shell命令启动应用的方法详解
2019/01/07 Python
Django使用中间键实现csrf认证详解
2019/07/22 Python
从训练好的tensorflow模型中打印训练变量实例
2020/01/20 Python
python读取文件指定行内容实例讲解
2020/03/02 Python
Django自定义列表 models字段显示方式
2020/04/03 Python
html5调用app分享功能示例(WebViewJavascriptBridge)
2018/03/21 HTML / CSS
荷兰和比利时时尚鞋店:Van Dalen
2018/04/23 全球购物
世界上最大的乐谱选择:Sheet Music Plus
2020/01/18 全球购物
法国购买二手电子产品网站:Asgoodasnew
2020/03/27 全球购物
英国时尚首饰品牌:Missoma
2020/06/29 全球购物
木马的传播途径主要有哪些
2016/04/08 面试题
CSS代码检查工具stylelint的使用方法详解
2021/03/27 HTML / CSS
个人充满哲理的自我评价
2014/02/20 职场文书
挂职自我鉴定
2014/02/26 职场文书
工地安全质量标语
2014/06/07 职场文书
合作协议书范文
2014/08/20 职场文书
党性心得体会
2014/09/03 职场文书
详解如何用Python实现感知器算法
2021/06/18 Python