PyTorch dropout设置训练和测试模式的实现


Posted in Python onMay 27, 2021

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的常见命令注入威胁
Feb 18 Python
Python多线程同步Lock、RLock、Semaphore、Event实例
Nov 21 Python
Python open()文件处理使用介绍
Nov 30 Python
批处理与python代码混合编程的方法
May 19 Python
python中强大的format函数实例详解
Dec 05 Python
Python设计模式之代理模式实例详解
Jan 19 Python
python redis 删除key脚本的实例
Feb 19 Python
python pandas cumsum求累计次数的用法
Jul 29 Python
Python安装依赖(包)模块方法详解
Feb 14 Python
基于python实现数组格式参数加密计算
Apr 21 Python
Flask模板引擎Jinja2使用实例
Apr 23 Python
如何使用Django Admin管理后台导入CSV
Nov 06 Python
pytorch Dropout过拟合的操作
浅谈pytorch中的dropout的概率p
May 27 #Python
让文件路径提取变得更简单的Python Path库
Pytorch中的数据集划分&正则化方法
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
You might like
php 分页原理详解
2009/08/21 PHP
探讨:如何编写PHP扩展
2013/06/13 PHP
PHP curl 获取响应的状态码的方法
2014/01/13 PHP
php生成圆角图片的方法
2015/04/07 PHP
[原创]php token使用与验证示例【测试可用】
2017/08/30 PHP
详解Laravel设置多态关系模型别名的方式
2019/10/17 PHP
浅谈Laravel模板实体转义带来的坑
2019/10/22 PHP
HTA版JSMin(省略修饰语若干)基于javascript语言编写
2009/12/24 Javascript
JS代码同步文本框内容的实例方法
2013/07/12 Javascript
AngularJs Javascript MVC 框架
2016/06/20 Javascript
js轮盘抽奖实例分析
2020/04/17 Javascript
JavaScript中利用构造器函数模拟类的方法
2017/02/16 Javascript
JS Testing Properties 判断属性是否在对象里的方法
2017/10/01 Javascript
bootstrap table合并行数据并居中对齐效果
2018/10/17 Javascript
vue中组件的过渡动画及实现代码
2018/11/21 Javascript
jQuery实现移动端下拉展现新的内容回弹动画
2020/06/24 jQuery
vue v-model的用法解析
2020/10/19 Javascript
Pyhton中单行和多行注释的使用方法及规范
2016/10/11 Python
python修改list中所有元素类型的三种方法
2018/04/09 Python
简单的Python调度器Schedule详解
2019/08/30 Python
使用Python实现分别输出每个数组
2019/12/06 Python
pytorch 中的重要模块化接口nn.Module的使用
2020/04/02 Python
windows10在visual studio2019下配置使用openCV4.3.0
2020/07/14 Python
Jupyter Notebook添加代码自动补全功能的实现
2021/01/07 Python
聪明的粉丝购买门票的地方:TickPick
2018/03/09 全球购物
西班牙最大的在线滑板和街头服饰商店:Fillow.net
2019/04/15 全球购物
荷兰在线啤酒店:Beerwulf
2019/08/26 全球购物
接口可以包含哪些成员
2012/09/30 面试题
车辆维修工自我评价怎么写
2013/09/20 职场文书
自荐信格式写作方法有哪些呢
2013/11/20 职场文书
国窖1573广告词
2014/03/21 职场文书
学校评语大全
2014/05/06 职场文书
党员先锋岗事迹材料
2014/05/08 职场文书
高中课程设置方案
2014/05/28 职场文书
项目投资意向书范本
2015/05/09 职场文书
聘任通知书
2015/09/21 职场文书