PyTorch dropout设置训练和测试模式的实现


Posted in Python onMay 27, 2021

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中获得当前目录和上级目录的实现方法
Oct 12 Python
利用pandas将numpy数组导出生成excel的实例
Jun 14 Python
Django中ORM外键和表的关系详解
May 20 Python
基于Python实现扑克牌面试题
Dec 11 Python
Python SSL证书验证问题解决方案
Jan 13 Python
Python基于read(size)方法读取超大文件
Mar 12 Python
django之从html页面表单获取输入的数据实例
Mar 16 Python
Python如何获取文件指定行的内容
May 27 Python
pygame用blit()实现动画效果的示例代码
May 28 Python
python基于socket函数实现端口扫描
May 28 Python
PyCharm2019.3永久激活破解详细图文教程,亲测可用(不定期更新)
Oct 29 Python
python 实现一个简单的线性回归案例
Dec 17 Python
pytorch Dropout过拟合的操作
浅谈pytorch中的dropout的概率p
May 27 #Python
让文件路径提取变得更简单的Python Path库
Pytorch中的数据集划分&正则化方法
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
You might like
PHP临时文件的安全性分析
2014/07/04 PHP
Laravel实现短信注册的示例代码
2018/05/29 PHP
对laravel的session获取与存取方法详解
2019/10/08 PHP
asp javascript 实现关闭窗口时保存数据的办法
2007/11/24 Javascript
JS求平均值的小例子
2013/11/29 Javascript
javascript自动恢复文本框点击清除后的默认文本
2016/01/12 Javascript
深入理解$.each和$(selector).each
2016/05/15 Javascript
深入解析桶排序算法及Node.js上JavaScript的代码实现
2016/07/06 Javascript
Javascript中的神器——Promise
2017/02/08 Javascript
JS实现仿UC浏览器前进后退效果的实例代码
2017/07/17 Javascript
js异步编程小技巧详解
2017/08/14 Javascript
vue中页面跳转拦截器的实现方法
2017/08/23 Javascript
vue使用rem实现 移动端屏幕适配
2018/09/26 Javascript
[05:04]完美世界携手游戏风云打造 卡尔工作室地图界面篇
2013/04/23 DOTA
[54:54]Newbee vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
Python中字典的基本知识初步介绍
2015/05/21 Python
python实现可以断点续传和并发的ftp程序
2016/09/13 Python
Python3.5编程实现修改IIS WEB.CONFIG的方法示例
2017/08/18 Python
pandas dataframe的合并实现(append, merge, concat)
2019/06/24 Python
Flask框架学习笔记之路由和反向路由详解【图文与实例】
2019/08/12 Python
python多进程(加入进程池)操作常见案例
2019/10/21 Python
Python 中 -m 的典型用法、原理解析与发展演变
2019/11/11 Python
Python tkinter常用操作代码实例
2020/01/03 Python
基于Python获取照片的GPS位置信息
2020/01/20 Python
Python3操作读写CSV文件使用包过程解析
2020/04/10 Python
keras模型保存为tensorflow的二进制模型方式
2020/05/25 Python
python属于软件吗
2020/06/18 Python
pytorch 查看cuda 版本方式
2020/06/23 Python
HTML5学习笔记之html5与传统html区别
2016/01/06 HTML / CSS
2013年大学生的自我鉴定
2013/10/24 职场文书
实用求职信范文分享
2013/12/25 职场文书
质量主管工作职责
2014/09/26 职场文书
奖金申请报告模板
2015/05/15 职场文书
《折线统计图》教学反思
2016/02/22 职场文书
授权协议书范本(3篇)
2019/10/15 职场文书
Python机器学习之逻辑回归
2021/05/11 Python