PyTorch dropout设置训练和测试模式的实现


Posted in Python onMay 27, 2021

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pyhton中单行和多行注释的使用方法及规范
Oct 11 Python
Python数据分析之如何利用pandas查询数据示例代码
Sep 01 Python
Python编程之黑板上排列组合,你舍得解开吗
Oct 30 Python
详解python字节码
Feb 07 Python
Python 获取中文字拼音首个字母的方法
Nov 28 Python
python 实现语音聊天机器人的示例代码
Dec 02 Python
python+opencv打开摄像头,保存视频、拍照功能的实现方法
Jan 08 Python
浅谈Django+Gunicorn+Nginx部署之路
Sep 11 Python
python操作gitlab API过程解析
Dec 27 Python
python super函数使用方法详解
Feb 14 Python
python 实现德洛内三角剖分的操作
Apr 22 Python
详解pytorch创建tensor函数
Mar 22 Python
pytorch Dropout过拟合的操作
浅谈pytorch中的dropout的概率p
May 27 #Python
让文件路径提取变得更简单的Python Path库
Pytorch中的数据集划分&正则化方法
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
You might like
PHP正则表达式之捕获组与非捕获组
2015/11/06 PHP
编写PHP脚本来实现WordPress中评论分页的功能
2015/12/10 PHP
PHP读取mssql json数据中文乱码的解决办法
2016/04/11 PHP
基于php实现的php代码加密解密类完整实例
2016/10/12 PHP
PHP程序员简单的开展服务治理架构操作详解(三)
2020/05/14 PHP
JavaScript this 深入理解
2009/07/30 Javascript
纯Javascript实现Windows 8 Metro风格实现
2013/10/15 Javascript
用JS实现3D球状标签云示例代码
2013/12/01 Javascript
JS获得浏览器版本和操作系统版本的例子
2014/05/13 Javascript
ECMAScript5(ES5)中bind方法使用小结
2015/05/07 Javascript
详解JavaScript中setSeconds()方法的使用
2015/06/11 Javascript
JS+CSS相对定位实现的下拉菜单
2015/10/06 Javascript
jquery mobile开发常见问题分析
2016/01/21 Javascript
bootstrap css样式之表单
2017/01/19 Javascript
JS原生轮播图的简单实现(推荐)
2017/07/22 Javascript
解决vue 格式化银行卡(信用卡)每4位一个符号隔断的问题
2018/09/14 Javascript
Vue.js中provide/inject实现响应式数据更新的方法示例
2019/10/16 Javascript
js回调函数仿360开机
2019/12/26 Javascript
Vue实现简单的留言板
2020/10/23 Javascript
[01:04]DOTA2:伟大的Roshan雕塑震撼来临
2015/01/30 DOTA
[04:05]TI9战队采访 - Natus Vincere
2019/08/22 DOTA
python在windows下实现备份程序实例
2014/07/04 Python
Python深入学习之装饰器
2014/08/31 Python
python制作websocket服务器实例分享
2016/11/20 Python
Python英文文本分词(无空格)模块wordninja的使用实例
2019/02/20 Python
Django之提交表单与前后端交互的方法
2019/07/19 Python
Python实现代码统计工具
2019/09/19 Python
浅谈关于html5中图片抛物线运动的一些心得
2018/01/09 HTML / CSS
Beach Bunny Swimwear官网:设计师泳装和性感比基尼
2019/03/13 全球购物
Interrail法国:乘火车探索欧洲,最受欢迎的欧洲铁路通票
2019/08/27 全球购物
耐克亚太地区:Nike APAC
2019/12/07 全球购物
Jowissa官方网站:瑞士制造的手表,优雅简约的设计
2020/07/29 全球购物
《草虫的村落》教学反思
2014/02/16 职场文书
2014年采购工作总结
2014/11/20 职场文书
综合素质自我评价评语
2015/03/06 职场文书
幼儿园大班开学寄语(2015秋季)
2015/05/27 职场文书