pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的subprocess模块总结
Nov 07 Python
python实现的简单文本类游戏实例
Apr 28 Python
python中的内置函数max()和min()及mas()函数的高级用法
Mar 29 Python
Python django使用多进程连接mysql错误的解决方法
Oct 08 Python
使用python进行拆分大文件的方法
Dec 10 Python
Python3 log10()函数简单用法
Feb 19 Python
人工神经网络算法知识点总结
Jun 11 Python
python原类、类的创建过程与方法详解
Jul 19 Python
python Paramiko使用示例
Sep 21 Python
Autopep8的使用(python自动编排工具)
Mar 02 Python
Python实现socket库网络通信套接字
Jun 04 Python
python识别围棋定位棋盘位置
Jul 26 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
thinkphp 多表 事务详解
2013/06/17 PHP
php操作xml入门之xml基本介绍及xml标签元素
2015/01/23 PHP
PHP实现根据数组的值进行分组的方法
2017/04/20 PHP
js传值 判断
2006/10/26 Javascript
收藏一些不常用,但是有用的代码
2007/03/12 Javascript
通过下拉框的值来确定输入框是否可以为空的代码
2011/10/18 Javascript
关于js日期转化为毫秒数“节省20%的效率和和节省9个字符“问题
2012/03/01 Javascript
js substr支持中文截取函数代码(中文是双字节)
2013/04/17 Javascript
浅析JS刷新框架中的其他页面 && JS刷新窗口方法汇总
2013/07/08 Javascript
Firefox和IE兼容性问题及解决方法总结
2013/10/08 Javascript
使用POST方式弹出窗口的两种方法示例介绍
2014/01/29 Javascript
PHP和NodeJs开发的应用如何共用Session
2015/04/16 NodeJs
javascript实现英文首字母大写
2015/04/23 Javascript
JavaScript中setUTCFullYear()方法的使用简介
2015/06/12 Javascript
原生js实现无缝轮播图效果
2017/01/11 Javascript
JS实现最简单的冒泡排序算法
2017/02/15 Javascript
Vue组件之自定义事件的功能图解
2018/02/01 Javascript
详解Vue中使用Echarts的两种方式
2018/07/03 Javascript
微信小程序 JS动态修改样式的实现方法
2018/12/16 Javascript
React通过redux-persist持久化数据存储的方法示例
2019/02/14 Javascript
微信小程序如何实现全局重新加载
2019/06/05 Javascript
vue路由切换之淡入淡出的简单实现
2019/10/31 Javascript
vue 实现v-for循环回来的数据动态绑定id
2019/11/07 Javascript
JS三级联动代码格式实例详解
2019/12/30 Javascript
[04:14]从西雅图到上海——玩家自制DOTA2主题歌曲应援TI9
2019/07/11 DOTA
python调用java的Webservice示例
2014/03/10 Python
python爬虫爬取某站上海租房图片
2018/02/04 Python
基于django2.2连oracle11g解决版本冲突的问题
2020/07/02 Python
Python unittest装饰器实现原理及代码
2020/09/08 Python
如何让IE9以下版本(ie6/7/8)认识html5元素
2013/04/01 HTML / CSS
美国电力供应商店/电气批发商:USESI
2018/10/12 全球购物
学生的自我鉴定范文
2013/10/24 职场文书
会议主持词
2014/03/17 职场文书
设备管理实施方案
2014/05/31 职场文书
不服劳动仲裁起诉书
2015/05/20 职场文书
私人贷款担保书该怎么写呢?
2019/07/02 职场文书