PyTorch dropout设置训练和测试模式的实现


Posted in Python onMay 27, 2021

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python + openpyxl处理excel2007文档思路以及心得
Jul 14 Python
python smtplib模块发送SSL/TLS安全邮件实例
Apr 08 Python
Python实现字典的key和values的交换
Aug 04 Python
初探TensorFLow从文件读取图片的四种方式
Feb 06 Python
详解Python并发编程之从性能角度来初探并发编程
Aug 23 Python
如何使用python进行pdf文件分割
Nov 11 Python
Python3常用内置方法代码实例
Nov 18 Python
python生成13位或16位时间戳以及反向解析时间戳的实例
Mar 03 Python
python爬虫实现获取下一页代码
Mar 13 Python
使用python采集Excel表中某一格数据
May 14 Python
python使用nibabel和sitk读取保存nii.gz文件实例
Jul 01 Python
python的setattr函数实例用法
Dec 16 Python
pytorch Dropout过拟合的操作
浅谈pytorch中的dropout的概率p
May 27 #Python
让文件路径提取变得更简单的Python Path库
Pytorch中的数据集划分&正则化方法
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
You might like
JAVA/JSP学习系列之四
2006/10/09 PHP
用PHP读注册表
2006/10/09 PHP
使用php发送有附件的电子邮件-(PHPMailer使用的实例分析)
2013/04/26 PHP
php中替换字符串中的空格为逗号','的方法
2014/06/09 PHP
php中让人头疼的浮点数运算分析
2016/10/10 PHP
PHP异常类及异常处理操作实例详解
2018/12/19 PHP
js兼容的placeholder属性详解
2013/08/18 Javascript
分享Javascript中最常用的55个经典小技巧
2013/11/29 Javascript
介绍JavaScript的一个微型模版
2015/06/24 Javascript
使用jquery实现仿百度自动补全特效
2015/07/23 Javascript
JS条形码(一维码)插件JsBarcode用法详解【编码类型、参数、属性】
2017/04/19 Javascript
javaScript 逻辑运算符使用技巧整理
2017/05/03 Javascript
Js利用console计算代码运行时间的方法示例
2017/09/24 Javascript
React组件内事件传参实现tab切换的示例代码
2018/07/04 Javascript
在vue中更换字体,本地存储字体非引用在线字体库的方法
2018/09/28 Javascript
jquery.param()实现数组或对象的序列化方法
2018/10/08 jQuery
详解关于Angular4 ng-zorro使用过程中遇到的问题
2018/12/05 Javascript
探索JavaScript中私有成员的相关知识
2019/06/13 Javascript
vue实现微信浏览器左上角返回按钮拦截功能
2020/01/18 Javascript
vue Element左侧无限级菜单实现
2020/06/10 Javascript
js+h5 canvas实现图片验证码
2020/10/11 Javascript
Antd的Table组件嵌套Table以及选择框联动操作
2020/10/24 Javascript
python实现支持目录FTP上传下载文件的方法
2015/06/03 Python
详解Python函数作用域的LEGB顺序
2016/05/14 Python
matplotlib在python上绘制3D散点图实例详解
2017/12/09 Python
详解Selenium+PhantomJS+python简单实现爬虫的功能
2019/07/14 Python
Python人工智能之路 jieba gensim 最好别分家之最简单的相似度实现
2019/08/13 Python
Farfetch澳大利亚官网:Farfetch Australia
2020/04/26 全球购物
中学自我评价
2014/01/31 职场文书
感恩节红领巾广播稿
2014/02/11 职场文书
中学生教师节演讲稿
2014/09/03 职场文书
2014第二批党员干部对照“四风”找差距检查材料思想汇报
2014/09/18 职场文书
社区国庆节活动总结
2015/03/23 职场文书
幼儿园小班班务总结
2015/08/03 职场文书
奖学金申请个人主要事迹材料
2015/11/04 职场文书
Python包argparse模块常用方法
2021/06/04 Python