pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3图片转换二进制存入mysql
Dec 06 Python
Python处理json字符串转化为字典的简单实现
Jul 07 Python
深入探究Django中的Session与Cookie
Jul 30 Python
python对离散变量的one-hot编码方法
Jul 11 Python
python数据结构学习之实现线性表的顺序
Sep 28 Python
Python的bit_length函数来二进制的位数方法
Aug 27 Python
浅谈在django中使用redirect重定向数据传输的问题
Mar 13 Python
Python文件时间操作步骤代码详解
Apr 13 Python
Python基于gevent实现高并发代码实例
May 15 Python
使用PyCharm安装pytest及requests的问题
Jul 31 Python
python将图片转为矢量图的方法步骤
Mar 30 Python
python opencv旋转图片的使用方法
Jun 04 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
PHP持久连接mysql_pconnect()函数使用介绍
2012/02/05 PHP
PHP 将逗号、空格、回车分隔的字符串转换为数组的函数
2012/06/07 PHP
与文件上传有关的php配置参数总结
2013/06/14 PHP
浅析php header 跳转
2013/06/17 PHP
php实现的css文件背景图片下载器代码
2014/11/11 PHP
PHP正则验证Email的方法
2015/06/15 PHP
PHP实现的贪婪算法实例
2017/10/17 PHP
PHP 应用容器化以及部署方法
2018/02/12 PHP
javascript最常用与实用的创建类的代码
2010/08/12 Javascript
从零开始学习jQuery (三) 管理jQuery包装集
2011/02/23 Javascript
Jquery解析json数据详解
2013/12/26 Javascript
javascript 拷贝节点cloneNode()使用介绍
2014/04/03 Javascript
JS函数this的用法实例分析
2015/02/05 Javascript
Actionscript与javascript交互实例程序(修改)
2016/09/22 Javascript
微信小程序 MD5的方法详解及实例代码
2017/03/10 Javascript
VUE Error: getaddrinfo ENOTFOUND localhost
2018/05/03 Javascript
通过nodejs 服务器读取HTML文件渲染到页面的方法
2018/05/17 NodeJs
Vue2.0 实现歌手列表滚动及右侧快速入口功能
2018/08/08 Javascript
ES6的解构赋值实例详解
2019/05/06 Javascript
package.json配置文件构成详解
2019/08/27 Javascript
用js编写留言板
2020/03/17 Javascript
python抓取网页图片示例(python爬虫)
2014/04/27 Python
跟老齐学Python之集合的关系
2014/09/24 Python
Python入门教程之if语句的用法
2015/05/14 Python
详解Python中的array数组模块相关使用
2016/07/05 Python
详解Python给照片换底色(蓝底换红底)
2019/03/22 Python
Python调用Windows命令打印文件
2020/02/07 Python
html5各种页面切换效果和模态对话框用法总结
2014/12/15 HTML / CSS
LN-CC日本:高端男装和女装的奢侈时尚目的地
2019/09/01 全球购物
C++:局部变量能否和全局变量重名
2014/03/03 面试题
药剂专业毕业生求职信
2014/06/24 职场文书
师范生见习自我总结
2015/06/23 职场文书
公安干警正风肃纪心得体会
2016/01/15 职场文书
Go并发4种方法简明讲解
2022/04/06 Golang
Python线程池与GIL全局锁实现抽奖小案例
2022/04/13 Python
python pandas 解析(读取、写入)CSV 文件的操作方法
2022/12/24 Python