pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python打开网页和暂停实例
Sep 30 Python
python Opencv将图片转为字符画
Feb 19 Python
Python之时间和日期使用小结
Feb 14 Python
python2和python3在处理字符串上的区别详解
May 29 Python
Python read函数按字节(字符)读取文件的实现
Jul 03 Python
Django分页功能的实现代码详解
Jul 29 Python
python中matplotlib条件背景颜色的实现
Sep 02 Python
Python实现CNN的多通道输入实例
Jan 17 Python
python实现扫雷小游戏
Apr 24 Python
tensorflow实现残差网络方式(mnist数据集)
May 26 Python
如何基于Python Matplotlib实现网格动画
Jul 20 Python
Python爬虫爬取有道实现翻译功能
Nov 27 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
使用数据库保存session的方法
2006/10/09 PHP
Symfony学习十分钟入门经典教程
2016/02/03 PHP
PHP的PDO预定义常量讲解
2019/01/24 PHP
PHP XML Expat解析器知识点总结
2019/02/15 PHP
Laravel5.7 Eloquent ORM快速入门详解
2019/04/12 PHP
gearman中worker常驻后台,导致MySQL server has gone away的解决方法
2020/02/27 PHP
PHP设计模式之 策略模式Strategy详解【对象行为型】
2020/05/01 PHP
PHP执行普通shell命令流程解析
2020/08/24 PHP
JQuery的html(data)方法与<script>脚本块的解决方法
2010/03/09 Javascript
javascript仿qq界面的折叠菜单实现代码
2012/12/12 Javascript
在easyUI开发中,出现jquery.easyui.min.js函数库问题的解决办法
2015/09/11 Javascript
Vue方法与事件处理器详解
2016/12/01 Javascript
微信小程序 自定义对话框实例详解
2017/01/20 Javascript
Angular4绑定html内容出现警告的处理方法
2017/11/03 Javascript
vue中组件的过渡动画及实现代码
2018/11/21 Javascript
基于layui的下拉列表的数据回显方法
2019/09/24 Javascript
在vue中利用全局路由钩子给url统一添加公共参数的例子
2019/11/01 Javascript
vue中实现点击按钮滚动到页面对应位置的方法(使用c3平滑属性实现)
2019/12/29 Javascript
了不起的11个JavaScript代码重构最佳实践小结
2021/01/11 Javascript
对python:threading.Thread类的使用方法详解
2019/01/31 Python
树莓派+摄像头实现对移动物体的检测
2019/06/22 Python
python 使用matplotlib 实现从文件中读取x,y坐标的可视化方法
2019/07/04 Python
Python requests模块cookie实例解析
2020/04/14 Python
安装pyecharts1.8.0版本后导入pyecharts模块绘图时报错: “所有图表类型将在 v1.9.0 版本开始强制使用 ChartItem 进行数据项配置 ”的解决方法
2020/08/18 Python
python中Mako库实例用法
2020/12/31 Python
家得宝官网:The Home Depot(全球最大的家居装饰专业零售商)
2018/12/17 全球购物
复古服装:RetroStage
2019/05/10 全球购物
新闻记者个人求职的自我评价
2013/11/28 职场文书
大一新生军训时的自我评价分享
2013/12/05 职场文书
环境科学专业优秀毕业生自荐书
2014/02/03 职场文书
电视节目策划方案
2014/05/16 职场文书
乡镇镇长个人整改措施
2014/10/01 职场文书
征求意见函
2015/06/05 职场文书
幼儿园托班开学寄语(2016春季)
2015/12/03 职场文书
JVM入门之类加载与字节码技术(类加载与类的加载器)
2021/06/15 Java/Android
前端vue+express实现文件的上传下载示例
2022/02/18 Vue.js