pytorch 实现在测试的时候启用dropout


Posted in Python onMay 27, 2021

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

pytorch 实现在测试的时候启用dropout

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

二.搭建神经网络

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

三.训练

pytorch 实现在测试的时候启用dropout

四.对比测试结果

注意:测试过程中,一定要注意模式切换

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python threading模块操作多线程介绍
Apr 08 Python
在Python的struct模块中进行数据格式转换的方法
Jun 17 Python
python使用MySQLdb访问mysql数据库的方法
Aug 03 Python
日常整理python执行系统命令的常见方法(全)
Oct 22 Python
Python实现的归并排序算法示例
Nov 21 Python
Django中cookie的基本使用方法示例
Feb 03 Python
python画折线图的程序
Jul 26 Python
python2和python3的输入和输出区别介绍
Nov 20 Python
Python实现微信好友的数据分析
Dec 16 Python
Pycharm学生免费专业版安装教程的方法步骤
Sep 24 Python
Python3.9.0 a1安装pygame出错解决全过程(小结)
Feb 02 Python
据Python爬虫不靠谱预测可知今年双十一销售额将超过6000亿元
Nov 11 Python
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
tensorflow中的数据类型dtype用法说明
May 26 #Python
详解Python魔法方法之描述符类
May 26 #Python
You might like
一个分页的论坛
2006/10/09 PHP
FCKeditor的安装(PHP)
2007/01/13 PHP
codeigniter自带数据库类使用方法说明
2014/03/25 PHP
destoon利用Rewrite规则设置网站安全
2014/06/21 PHP
PHP使用array_fill定义多维数组的方法
2015/03/18 PHP
如何在HTML 中嵌入 PHP 代码
2015/05/13 PHP
详解PHP中的mb_detect_encoding函数使用方法
2015/08/18 PHP
PHP使用MPDF类生成PDF的方法
2015/12/08 PHP
php获得客户端浏览器名称及版本的方法(基于ECShop函数)
2015/12/23 PHP
laravel 5.5 关闭token的3种实现方式
2019/10/24 PHP
分析 JavaScript 中令人困惑的变量赋值
2007/08/13 Javascript
Javascript document.referrer判断访客来源网址
2020/05/15 Javascript
可以将word转成html的js代码
2010/04/11 Javascript
jquery实现效果比较好的table选中行颜色
2014/03/25 Javascript
jQuery选择器源码解读(八):addCombinator函数
2015/03/31 Javascript
跟我学习javascript的call(),apply(),bind()与回调
2015/11/16 Javascript
自学实现angularjs依赖注入
2016/12/20 Javascript
Vue.js中对css的操作(修改)具体方式详解
2018/10/30 Javascript
JavaScript实现图片的放大缩小及拖拽功能示例
2019/05/14 Javascript
layer弹出层自适应高度,垂直水平居中的实现
2019/09/16 Javascript
整理 node-sass 安装失败的原因及解决办法(小结)
2020/02/19 Javascript
uni-app使用countdown插件实现倒计时
2020/11/01 Javascript
[03:39]DOTA2英雄梦之声_第05期_幽鬼
2014/06/23 DOTA
Python实现获取命令行输出结果的方法
2017/06/10 Python
Python实现解析Bit Torrent种子文件内容的方法
2017/08/29 Python
利用python操作SQLite数据库及文件操作详解
2017/09/22 Python
对python中使用requests模块参数编码的不同处理方法
2018/05/18 Python
PyCharm-错误-找不到指定文件python.exe的解决方法
2019/07/01 Python
python3 写一个WAV音频文件播放器的代码
2019/09/27 Python
python通过文本在一个图中画多条线的实例
2020/02/21 Python
Python爬虫获取页面所有URL链接过程详解
2020/06/04 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
2020/07/07 Python
美国中西部家用医疗设备商店:Med Mart(轮椅、踏板车、升降机等)
2019/04/26 全球购物
国贸专业的职业规划书
2014/03/15 职场文书
go语言求任意类型切片的长度操作
2021/04/26 Golang
win10电脑关机快捷键是哪个 win10快速关机的几种方法
2022/08/14 数码科技