pytorch 实现在测试的时候启用dropout


Posted in Python onMay 27, 2021

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

pytorch 实现在测试的时候启用dropout

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

二.搭建神经网络

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

三.训练

pytorch 实现在测试的时候启用dropout

四.对比测试结果

注意:测试过程中,一定要注意模式切换

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python科学计算环境推荐——Anaconda
Jun 30 Python
python编码最佳实践之总结
Feb 14 Python
浅谈Python类里的__init__方法函数,Python类的构造函数
Dec 10 Python
利用Python如何实现数据驱动的接口自动化测试
May 11 Python
解决python nohup linux 后台运行输出的问题
May 11 Python
Python基于jieba库进行简单分词及词云功能实现方法
Jun 16 Python
Python动态参数/命名空间/函数嵌套/global和nonlocal
May 29 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 Python
Keras Convolution1D与Convolution2D区别说明
May 22 Python
浅析Python中字符串的intern机制
Oct 03 Python
Python函数调用追踪实现代码
Nov 27 Python
python 闭包函数详细介绍
Apr 19 Python
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
tensorflow中的数据类型dtype用法说明
May 26 #Python
详解Python魔法方法之描述符类
May 26 #Python
You might like
php array_map array_multisort 高效处理多维数组排序
2009/06/11 PHP
php实现图片文件与下载文件防盗链的方法
2014/11/03 PHP
PHP请求远程地址设置超时时间的解决方法
2016/10/29 PHP
php记录搜索引擎爬行记录的实现代码
2018/03/02 PHP
基于jQuery+HttpHandler实现图片裁剪效果代码(适用于论坛, SNS)
2011/09/02 Javascript
用jquery实现点击栏目背景色改变
2012/12/10 Javascript
js展开闭合效果演示代码
2013/07/24 Javascript
一个很有趣3D球状标签云兼容IE8
2014/08/22 Javascript
JavaScript模拟实现键盘打字效果
2015/06/29 Javascript
JavaScript变量的作用域全解析
2015/08/14 Javascript
jQuery+css3实现转动的正方形效果(附demo源码下载)
2016/01/27 Javascript
JS组件Bootstrap Table使用方法详解
2016/02/02 Javascript
浅谈Jquery中Ajax异步请求中的async参数的作用
2016/06/06 Javascript
Bootstrap轮播插件中图片变形的终极解决方案 使用jqthumb.js
2016/07/10 Javascript
jQuery阻止移动端遮罩层后页面滚动
2017/03/15 Javascript
ES6新特性之函数的扩展实例详解
2017/04/01 Javascript
详解ajax的data参数错误导致页面崩溃
2018/04/30 Javascript
vue实现通讯录功能
2018/07/14 Javascript
仿iPhone通讯录制作小程序自定义选择组件的实现
2019/05/23 Javascript
js实现mp3录音通过websocket实时传送+简易波形图效果
2020/06/12 Javascript
vue 限制input只能输入正数的操作
2020/08/05 Javascript
vue 如何使用递归组件
2020/10/23 Javascript
python删除文件示例分享
2014/01/28 Python
python中in在list和dict中查找效率的对比分析
2018/05/04 Python
Python数据分析matplotlib设置多个子图的间距方法
2018/08/03 Python
详解python 模拟豆瓣登录(豆瓣6.0)
2019/04/18 Python
几款Python编译器比较与推荐(小结)
2020/10/15 Python
Pycharm操作Git及GitHub的步骤详解
2020/10/27 Python
python 利用matplotlib在3D空间中绘制平面的案例
2021/02/06 Python
企业演讲稿范文
2013/12/28 职场文书
主持人演讲稿范文
2013/12/28 职场文书
财产公证书样本
2014/04/04 职场文书
小学学习委员竞选稿
2015/11/20 职场文书
初三化学教学反思
2016/02/22 职场文书
一条慢SQL语句引发的改造之路
2022/03/16 MySQL
HTML常用标签超详细整理
2022/03/19 HTML / CSS