pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pyramid添加Middleware的方法实例
Nov 27 Python
探索Python3.4中新引入的asyncio模块
Apr 08 Python
python将每个单词按空格分开并保存到文件中
Mar 19 Python
pandas获取groupby分组里最大值所在的行方法
Apr 20 Python
ActiveMQ:使用Python访问ActiveMQ的方法
Jan 30 Python
python爬虫 execjs安装配置及使用
Jul 30 Python
python集合常见运算案例解析
Oct 17 Python
python 变量初始化空列表的例子
Nov 28 Python
Django之form组件自动校验数据实现
Jan 14 Python
python中如何写类
Jun 29 Python
Python Pandas模块实现数据的统计分析的方法
Jun 24 Python
Python制作一个随机抽奖小工具的实现
Jul 07 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
php实现多城市切换特效
2015/08/09 PHP
javascript 装载iframe子页面,自适应高度
2009/03/20 Javascript
JQuery中的html()、text()、val()区别示例介绍
2014/09/01 Javascript
jQuery实现下滑菜单导航效果代码
2015/08/25 Javascript
js仿微博实现统计字符和本地存储功能
2015/12/22 Javascript
js简单判断移动端系统的方法
2016/02/25 Javascript
JQuery 在文档中查找指定name的元素并移除的实现方法
2016/05/19 Javascript
利用nodejs监控文件变化并使用sftp上传到服务器
2017/02/18 NodeJs
js判断PC端与移动端跳转
2020/12/24 Javascript
jQuery编写textarea输入字数限制代码
2017/03/23 jQuery
原生javascript实现分页效果
2017/04/21 Javascript
JS作用域链详解
2017/06/26 Javascript
利用angular自动编译andriod APK的绕坑经历分享
2019/03/08 Javascript
一些你可能不熟悉的JS知识点总结
2019/03/15 Javascript
jQuery实现动态生成年月日级联下拉列表示例
2019/05/11 jQuery
帮你彻底搞懂JS中的prototype、__proto__与constructor(图解)
2019/08/23 Javascript
vuex+axios+element-ui实现页面请求loading操作示例
2020/02/02 Javascript
python获取网页状态码示例
2014/03/30 Python
Python numpy数组转置与轴变换
2019/11/15 Python
Python批量处理csv并保存过程解析
2020/05/16 Python
python基于pygame实现飞机大作战小游戏
2020/11/19 Python
Watchshop德国:欧洲在线手表No.1
2019/06/20 全球购物
人力资源管理毕业生自荐信
2013/11/21 职场文书
经贸日语专业个人求职信范文
2013/12/28 职场文书
会计大学生职业生涯规划书范文
2014/01/13 职场文书
韩国商务邀请函
2014/01/14 职场文书
2014年派出所工作总结
2014/11/21 职场文书
2015年保洁工作总结范文
2015/04/28 职场文书
2015年度电厂个人工作总结
2015/05/13 职场文书
校园新闻稿范文
2015/07/18 职场文书
父亲节感言
2015/08/03 职场文书
golang通过递归遍历生成树状结构的操作
2021/04/28 Golang
MySQL sql_mode的使用详解
2021/05/08 MySQL
python3实现Dijkstra算法最短路径的实现
2021/05/12 Python
教你使用pyinstaller打包Python教程
2021/05/27 Python
Python通过loop.run_in_executor执行同步代码 同步变为异步
2022/04/11 Python