PyTorch dropout设置训练和测试模式的实现


Posted in Python onMay 27, 2021

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
如何使用七牛Python SDK写一个同步脚本及使用教程
Aug 23 Python
python处理html转义字符的方法详解
Jul 01 Python
Python类的动态修改的实例方法
Mar 24 Python
python 字符串转列表 list 出现\ufeff的解决方法
Jun 22 Python
Python利用递归和walk()遍历目录文件的方法示例
Jul 14 Python
使用Python获取网段IP个数以及地址清单的方法
Nov 01 Python
Python两个字典键同值相加的几种方法
Mar 05 Python
Django中create和save方法的不同
Aug 13 Python
python使用beautifulsoup4爬取酷狗音乐代码实例
Dec 04 Python
Python安装tar.gz格式文件方法详解
Jan 19 Python
Python中 Global和Nonlocal的用法详解
Jan 20 Python
Python opencv缺陷检测的实现及问题解决
Apr 24 Python
pytorch Dropout过拟合的操作
浅谈pytorch中的dropout的概率p
May 27 #Python
让文件路径提取变得更简单的Python Path库
Pytorch中的数据集划分&正则化方法
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
You might like
成为好程序员必须避免的5个坏习惯
2014/07/04 PHP
php读取csv数据保存到数组的方法
2015/01/03 PHP
Nigma vs Alliance BO5 第一场2.14
2021/03/10 DOTA
Jquery实现页面加载时弹出对话框代码
2013/04/19 Javascript
js 弹出新页面避免被浏览器、ad拦截的一种新方法
2014/04/30 Javascript
jQuery实现列表内容的动态载入特效
2015/08/08 Javascript
Bootstrap每天必学之栅格系统(布局)
2015/11/25 Javascript
jquery中ajax跨域方法实例分析
2015/12/18 Javascript
JS简单获取客户端IP地址的方法【调用搜狐接口】
2016/09/05 Javascript
轻松掌握JavaScript状态模式
2016/09/07 Javascript
jQuery实现边框动态效果的实例代码
2016/09/23 Javascript
[原创]JS基于FileSaver.js插件实现文件保存功能示例
2016/12/08 Javascript
浅析jsopn跨域请求原理及cors(跨域资源共享)的完美解决方法
2017/02/06 Javascript
Jquery获取radio选中的值
2017/05/05 jQuery
用js将long型数据转换成date型或datetime型的实例
2017/07/03 Javascript
nodejs超出最大的调用栈错误问题
2017/12/27 NodeJs
Js 利用正则表达式和replace函数获取string中所有被匹配到的文本(推荐)
2018/10/28 Javascript
vue动态绑定class的几种常用方式小结
2019/05/21 Javascript
如何用原生js写一个弹窗消息提醒插件
2019/05/24 Javascript
对layui初始化列表的CheckBox属性详解
2019/09/13 Javascript
vuex中store存储store.commit和store.dispatch的用法
2020/07/24 Javascript
基于JavaScript实现随机点名器
2021/02/25 Javascript
[52:41]OG vs IG 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/20 DOTA
解决DataFrame排序sort的问题
2018/06/07 Python
python高级特性和高阶函数及使用详解
2018/10/17 Python
简单了解Django ORM常用字段类型及参数配置
2020/01/07 Python
Django ORM判断查询结果是否为空,判断django中的orm为空实例
2020/07/09 Python
Python读取Excel一列并计算所有对象出现次数的方法
2020/09/04 Python
求两个数的乘积和商数,该作用由宏定义来实现
2013/03/13 面试题
化工实习心得体会
2014/09/09 职场文书
法定代表人授权委托书格式
2014/10/14 职场文书
督导岗位职责范本
2015/04/10 职场文书
毕业设计致谢语
2015/05/14 职场文书
python实现简单的井字棋
2021/05/26 Python
mysql外连接与内连接查询的不同之处
2021/06/03 MySQL
MySQL系列之四 SQL语法
2021/07/02 MySQL