pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动安装pip
Apr 24 Python
利用Django内置的认证视图实现用户密码重置功能详解
Nov 24 Python
python中字符串比较使用is、==和cmp()总结
Mar 18 Python
python获取当前目录路径和上级路径的实例
Apr 26 Python
python爬取微信公众号文章
Aug 31 Python
selenium + python 获取table数据的示例讲解
Oct 13 Python
Python 获取项目根路径的代码
Sep 27 Python
nginx+uwsgi+django环境搭建的方法步骤
Nov 25 Python
浅谈Python 函数式编程
Jun 20 Python
通过代码实例了解Python3编程技巧
Oct 13 Python
python+pytest接口自动化之token关联登录的实现
Apr 06 Python
Python first-order-model实现让照片动起来
Jun 25 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
五个PHP程序员工具
2008/05/26 PHP
PHP二维数组实现去除重复项的方法【保留各个键值】
2017/12/21 PHP
javascript写的简单的计算器,内容很多,方法实用,推荐
2011/12/29 Javascript
JavaScript中的全局对象介绍
2015/01/01 Javascript
JavaScript实现简单图片翻转的方法
2015/04/17 Javascript
js实现刷新iframe的方法汇总
2015/04/27 Javascript
JavaScript学习笔记(三):JavaScript也有入口Main函数
2015/09/12 Javascript
JS代码随机生成姓名、手机号、身份证号、银行卡号
2016/04/27 Javascript
使用jQuery Mobile框架开发移动端Web App的入门教程
2016/05/17 Javascript
微信小程序 教程之注册程序
2016/10/17 Javascript
jQuery中值得注意的trigger方法浅析
2016/12/12 Javascript
Angular.JS内置服务$http对数据库的增删改使用教程
2017/05/07 Javascript
javascript用rem来做响应式开发
2018/01/13 Javascript
解决js ajax同步请求造成浏览器假死的问题
2018/01/18 Javascript
微信、QQ、微博、Safari中使用js唤起App
2018/01/24 Javascript
jQuery实现的简单对话框拖动功能示例
2018/06/05 jQuery
vue项目打包上传github并制作预览链接(pages)
2019/04/19 Javascript
微信小程序实现拖拽功能
2019/09/26 Javascript
jQuery实现的上拉刷新功能组件示例
2020/05/01 jQuery
Element InputNumber计数器的使用方法
2020/07/27 Javascript
javascript实现贪吃蛇游戏(娱乐版)
2020/08/17 Javascript
jQuery实现回到顶部效果
2020/10/19 jQuery
pandas.DataFrame 根据条件新建列并赋值的方法
2018/04/08 Python
基于python实现聊天室程序
2018/07/27 Python
使用Filter过滤python中的日志输出的实现方法
2019/07/17 Python
快速解决Django关闭Debug模式无法加载media图片与static静态文件
2020/04/07 Python
HTML高亮关键字的实现代码
2018/10/22 HTML / CSS
html5开发之viewport使用
2013/10/17 HTML / CSS
简单整理HTML5的基本特性和语法
2016/02/18 HTML / CSS
美国著名的婴儿学步鞋老品牌:Robeez
2016/08/20 全球购物
美国专业级皮肤病和spa品质护肤品的高级零售网站:SkinCareRx
2017/02/06 全球购物
工程专业求职自荐书范文
2014/02/18 职场文书
餐厅执行经理岗位职责范本
2014/02/26 职场文书
后勤主管岗位职责
2014/03/01 职场文书
2014年工作总结及2015工作计划
2014/12/12 职场文书
Python使用socket去实现TCP客户端和TCP服务端
2022/04/12 Python