PyTorch dropout设置训练和测试模式的实现


Posted in Python onMay 27, 2021

看代码吧~

class Net(nn.Module):
…
model = Net()
…
model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响

补充:Pytorch遇到的坑——训练模式和测试模式切换

由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要区分开来。

使用时切记要根据实际情况切换:

model.train()
model.eval()

补充:Pytorch在测试与训练过程中的验证结果不一致问题

引言

今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。

现象

此前的错误代码是

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    model.eval()
    model.forward()

应该改为

input_cpu = torch.ones((1, 2, 160, 160))
    target_cpu =torch.ones((1, 2, 160, 160))
    target_gpu, input_gpu = target_cpu.cuda(), input_cpu.cuda()
    model.set_input_2(input_gpu, target_gpu)
    # 先forward再eval
    model.forward()
    model.eval()

当时有个疑虑,为什么要在forward后面再加eval(),查了下相关资料,主要是在BN层以及Dropout的问题。当使用eval()时,模型会自动固定BN层以及Dropout,选取训练好的值,否则则会取平均,可能导致生成的图片颜色失真。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取局域网占带宽最大3个ip的方法
Jul 09 Python
Python处理XML格式数据的方法详解
Mar 21 Python
Python re 模块findall() 函数返回值展现方式解析
Aug 09 Python
Windows10下Tensorflow2.0 安装及环境配置教程(图文)
Nov 21 Python
tensorflow实现打印ckpt模型保存下的变量名称及变量值
Jan 04 Python
解决Python import docx出错DLL load failed的问题
Feb 13 Python
Python面向对象程序设计之类和对象、实例变量、类变量用法分析
Mar 23 Python
python数据库编程 ODBC方式实现通讯录
Mar 27 Python
django 解决自定义序列化返回处理数据为null的问题
May 20 Python
python语音识别指南终极版(有这一篇足矣)
Sep 09 Python
Python爬取你好李焕英豆瓣短评生成词云的示例代码
Feb 24 Python
这样写python注释让代码更加的优雅
Jun 02 Python
pytorch Dropout过拟合的操作
浅谈pytorch中的dropout的概率p
May 27 #Python
让文件路径提取变得更简单的Python Path库
Pytorch中的数据集划分&正则化方法
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
You might like
PHP编实现程动态图像的创建代码
2008/09/28 PHP
打造超酷的PHP数据饼图效果实现代码
2011/11/23 PHP
php解析xml提示Invalid byte 1 of 1-byte UTF-8 sequence错误的处理方法
2013/11/14 PHP
PHP5.5在windows安装使用memcached服务端的方法
2014/04/16 PHP
WordPress中设置Post Type自定义文章类型的实例教程
2016/05/10 PHP
利用PHP访问带有密码的Redis方法示例
2017/02/09 PHP
Hutia 的 JS 代码集
2006/10/24 Javascript
js下通过prototype扩展实现indexOf的代码
2010/12/08 Javascript
js异常捕获方法介绍
2013/04/10 Javascript
jquery scroll()区分横向纵向滚动条的方法
2014/04/04 Javascript
JS/Jquery判断对象为空的方法
2015/06/11 Javascript
详解Wondows下Node.js使用MongoDB的环境配置
2016/03/01 Javascript
修改js confirm alert 提示框文字的简单实例
2016/06/10 Javascript
JS实现字符串转驼峰格式的方法
2016/12/16 Javascript
vue-cli 组件的导入与使用教程详解
2018/04/11 Javascript
vue router+vuex实现首页登录验证判断逻辑
2018/05/17 Javascript
微信小程序使用gitee进行版本管理
2018/09/20 Javascript
Vue中的vue-resource示例详解
2018/11/02 Javascript
async/await让异步操作同步执行的方法详解
2019/11/01 Javascript
原生JavaScript实现五子棋游戏
2020/11/09 Javascript
python中as用法实例分析
2015/04/30 Python
使用CodeMirror实现Python3在线编辑器的示例代码
2019/01/14 Python
python super函数使用方法详解
2020/02/14 Python
python多进程下的生产者和消费者模型
2020/05/07 Python
keras slice layer 层实现方式
2020/06/11 Python
python怎么判断模块安装完成
2020/06/19 Python
Python实现上下文管理器的方法
2020/08/07 Python
python爬虫中抓取指数的实例讲解
2020/12/01 Python
英国礼品和生活方式品牌:Treat Republic
2020/11/21 全球购物
2014法院干警廉洁警示教育思想汇报
2014/09/13 职场文书
教代会开幕词
2015/01/28 职场文书
紧急通知
2015/04/17 职场文书
导游词之宁夏贺兰山岩画
2019/11/08 职场文书
Promise面试题详解之控制并发
2021/05/14 面试题
windows安装python超详细图文教程
2021/05/21 Python
自从在 IDEA 中用了热部署神器 JRebel 之后,开发效率提升了 10(真棒)
2021/06/26 Java/Android