PyTorch dropout设置训练和测试模式的实现


Posted in Python onMay 27, 2021

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中wx将图标显示在右下角的脚本代码
Mar 08 Python
Python实现各种排序算法的代码示例总结
Dec 11 Python
python批量添加zabbix Screens的两个脚本分享
Jan 16 Python
Python使用base64模块进行二进制数据编码详解
Jan 11 Python
在pycharm中python切换解释器失败的解决方法
Oct 29 Python
Python使用pydub库对mp3与wav格式进行互转的方法
Jan 10 Python
windows系统中Python多版本与jupyter notebook使用虚拟环境的过程
May 15 Python
bluepy 一款python封装的BLE利器简单介绍
Jun 25 Python
解决python3中os.popen()出错的问题
Nov 19 Python
2021年值得向Python开发者推荐的VS Code扩展插件
Jan 25 Python
python如何进行基准测试
Apr 26 Python
Python使用plt.boxplot()函数绘制箱图、常用方法以及含义详解
Aug 14 Python
pytorch Dropout过拟合的操作
浅谈pytorch中的dropout的概率p
May 27 #Python
让文件路径提取变得更简单的Python Path库
Pytorch中的数据集划分&正则化方法
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
You might like
PHPShop存在多个安全漏洞
2006/10/09 PHP
php计算几分钟前、几小时前、几天前的几个函数、类分享
2014/04/09 PHP
PHP封装的XML简单操作类完整实例
2017/11/13 PHP
PHP通过文件路径获取文件名的实例代码
2018/10/14 PHP
PHP的mysqli_thread_id()函数讲解
2019/01/24 PHP
Swoole4.4协程抢占式调度器详解
2019/05/23 PHP
让你的PHP,APACHE,NGINX支持大文件上传
2021/03/09 PHP
用户注册常用javascript代码
2009/08/29 Javascript
深入理解JavaScript定时机制
2010/10/29 Javascript
ASP.NET jQuery 实例15 通过控件CustomValidator验证CheckBoxList
2012/02/03 Javascript
Javascript 按位与运算符 (&)使用介绍
2014/02/04 Javascript
js实现的标题栏新消息闪烁提示效果
2014/06/06 Javascript
javascript模拟post提交隐藏地址栏的参数
2014/09/03 Javascript
innerHTML属性,outerHTML属性,textContent属性,innerText属性区别详解
2015/03/13 Javascript
微信小程序图片横向左右滑动案例
2017/05/19 Javascript
jQuery扇形定时器插件pietimer使用方法详解
2017/07/18 jQuery
Javascript中绑定click事件的四种方式介绍
2018/10/26 Javascript
Nuxt.js实现一个SSR的前端博客的示例代码
2019/09/06 Javascript
vue css 引入asstes中的图片无法显示的四种解决方法
2020/03/16 Javascript
深入webpack打包原理及loader和plugin的实现
2020/05/06 Javascript
JS如何生成动态列表
2020/09/22 Javascript
Python(Tornado)模拟登录小米抢手机
2013/11/12 Python
python实现的简单窗口倒计时界面实例
2015/05/05 Python
TensorFlow实现随机训练和批量训练的方法
2018/04/28 Python
python实现名片管理系统项目
2019/04/26 Python
Python 3.8 新功能全解
2019/07/25 Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
2019/08/19 Python
python实现上传文件到linux指定目录的方法
2020/01/03 Python
Pytorch 扩展Tensor维度、压缩Tensor维度的方法
2020/09/09 Python
Html5大文件断点续传实现方法
2015/12/05 HTML / CSS
德国苹果商店:MacTrade
2020/05/18 全球购物
考试不及格检讨书
2014/01/09 职场文书
数控专业毕业生求职信
2014/06/12 职场文书
送达通知书
2015/04/25 职场文书
公安干警正风肃纪心得体会
2016/01/15 职场文书
委托书范本格式
2019/04/18 职场文书