pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用pymysql小技巧
Jun 04 Python
在IPython中进行Python程序执行时间的测量方法
Nov 01 Python
10分钟教你用Python实现微信自动回复功能
Nov 28 Python
解决django中ModelForm多表单组合的问题
Jul 18 Python
Python3 JSON编码解码方法详解
Sep 06 Python
TensorFlow实现自定义Op方式
Feb 04 Python
Python打包模块wheel的使用方法与将python包发布到PyPI的方法详解
Feb 12 Python
python中的 zip函数详解及用法举例
Feb 16 Python
用python发送微信消息
Dec 21 Python
Django一小时写出账号密码管理系统
Apr 29 Python
python之PySide2安装使用及QT Designer UI设计案例教程
Jul 26 Python
一篇文章弄懂Python中的内建函数
Aug 07 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
php md5下16位和32位的实现代码
2008/04/09 PHP
解析zend Framework如何自动加载类
2013/06/28 PHP
YII动态模型(动态表名)支持分析
2016/03/29 PHP
PHP将MySQL的查询结果转换为数组并用where拼接的示例
2016/05/13 PHP
php实现有序数组打印或排序的方法【附Python、C及Go语言实现代码】
2016/11/10 PHP
PHP单例模式详解及实例代码
2016/12/21 PHP
Javascript实例教程(19) 使用HoTMetal(2)
2006/12/23 Javascript
一个简单的瀑布流效果(主体形式自写)
2013/05/27 Javascript
JQuery插件ajaxfileupload.js异步上传文件实例
2015/05/19 Javascript
JavaScript的Date()方法使用详解
2015/06/09 Javascript
js实现的倒计时按钮实例
2015/06/24 Javascript
Bootstrap实现input控件失去焦点时验证
2016/08/04 Javascript
Vue.js组件tabs实现选项卡切换效果
2016/12/01 Javascript
使用JavaScript实现一个小程序之99乘法表
2017/09/21 Javascript
vue2.0组件之间传值、通信的多种方式(干货)
2018/02/10 Javascript
Intellij IDEA搭建vue-cli项目的方法步骤
2018/10/20 Javascript
python实现的简单窗口倒计时界面实例
2015/05/05 Python
python实现音乐下载器
2018/04/15 Python
python队列queue模块详解
2018/04/27 Python
python的json包位置及用法总结
2020/06/21 Python
python求解汉诺塔游戏
2020/07/09 Python
python中return不返回值的问题解析
2020/07/22 Python
python爬虫看看虎牙女主播中谁最“顶”步骤详解
2020/12/01 Python
CSS3 实现雷达扫描图的示例代码
2020/09/21 HTML / CSS
菲律宾票务网站:StubHub菲律宾
2018/04/21 全球购物
大学生自我鉴定范文模板
2014/01/21 职场文书
地理信息科学专业推荐信
2014/09/08 职场文书
工伤事故赔偿协议书(标准)
2014/09/29 职场文书
党委书记群众路线对照检查材料思想汇报
2014/10/04 职场文书
2015年感恩父亲节演讲稿
2015/03/19 职场文书
2015年读书月活动总结
2015/03/26 职场文书
老员工辞职信范文
2015/05/12 职场文书
运动会三级跳加油稿
2015/07/21 职场文书
pytorch 如何使用batch训练lstm网络
2021/05/28 Python
Python实现列表拼接和去重的三种方式
2021/07/02 Python
【海涛解说】暗牧也疯狂,牛蛙成配角
2022/04/01 DOTA