pytorch 实现在测试的时候启用dropout


Posted in Python onMay 27, 2021

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

pytorch 实现在测试的时候启用dropout

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

二.搭建神经网络

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

三.训练

pytorch 实现在测试的时候启用dropout

四.对比测试结果

注意:测试过程中,一定要注意模式切换

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
解析Python中的二进制位运算符
May 13 Python
Python功能键的读取方法
May 28 Python
使用Python的Bottle框架写一个简单的服务接口的示例
Aug 25 Python
Flask框架的学习指南之制作简单blog系统
Nov 20 Python
tensorflow入门之训练简单的神经网络方法
Feb 26 Python
python机器学习之随机森林(七)
Mar 26 Python
Python多进程原理与用法分析
Aug 21 Python
Python 数据库操作 SQLAlchemy的示例代码
Feb 18 Python
django项目用higcharts统计最近七天文章点击量
Aug 17 Python
python 在右键菜单中加入复制目标文件的有效存放路径(单斜杠或者双反斜杠)
Apr 08 Python
Python基于Hypothesis测试库生成测试数据
Apr 29 Python
Python 用户输入和while循环的操作
May 23 Python
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
tensorflow中的数据类型dtype用法说明
May 26 #Python
详解Python魔法方法之描述符类
May 26 #Python
You might like
php xml常用函数的集合(比较详细)
2013/06/06 PHP
yii用户注册表单验证实例
2015/12/26 PHP
在laravel中实现ORM模型使用第二个数据库设置
2019/10/24 PHP
jquery text,radio,checkbox,select操作实现代码
2009/07/09 Javascript
javascript 冒泡排序 正序和倒序实现代码
2010/12/14 Javascript
jQuery+JSON+jPlayer实现QQ空间音乐查询功能示例
2013/06/17 Javascript
jquery slibings选取同级其他元素的实现代码
2013/11/15 Javascript
使用jquery实现的一个图片延迟加载插件(含图片延迟加载原理)
2014/06/05 Javascript
JavaScript模版引擎的基本实现方法浅析
2016/02/15 Javascript
快速掌握Node.js模块封装及使用
2016/03/21 Javascript
JS实现支持Ajax验证的表单插件
2016/03/24 Javascript
限制只能输入数字的实现代码
2016/05/16 Javascript
jquery插件treegrid树状表格的使用方法详解(.Net平台)
2017/01/03 Javascript
详解webpack分离css单独打包
2017/06/21 Javascript
vue和webpack打包项目相对路径修改的方法
2018/06/15 Javascript
vue利用v-for嵌套输出多层对象,分别输出到个表的方法
2018/09/07 Javascript
[01:07]DOTA2次级职业联赛 - Fpb战队宣传片
2014/12/01 DOTA
[01:19:54]DOTA2上海特级锦标赛主赛事日 - 2 败者组第二轮#1Alliance VS EHOME
2016/03/03 DOTA
[02:58]魔廷新尊——痛苦女王至宝语音台词节选
2020/06/14 DOTA
遍历python字典几种方法总结(推荐)
2016/09/11 Python
Django查询数据库的性能优化示例代码
2017/09/24 Python
Django ORM框架的定时任务如何使用详解
2017/10/19 Python
Python实现的归并排序算法示例
2017/11/21 Python
python GUI框架pyqt5 对图片进行流式布局的方法(瀑布流flowlayout)
2020/03/12 Python
django xadmin中form_layout添加字段显示方式
2020/03/30 Python
安装多个版本的TensorFlow的方法步骤
2020/04/21 Python
django 数据库 get_or_create函数返回值是tuple的问题
2020/05/15 Python
CSS3的calc()做响应模式布局的实现方法
2017/09/06 HTML / CSS
澳大利亚男士西服品牌:M.J.Bale
2018/02/06 全球购物
波兰快递服务:Globkurier.pl
2019/11/08 全球购物
本科生职业生涯规划书范文
2014/01/21 职场文书
巾帼文明岗申报材料
2014/05/01 职场文书
本科生求职信
2014/06/17 职场文书
五好家庭事迹材料
2014/12/20 职场文书
2015年六年级班主任工作总结
2015/10/15 职场文书
Tomcat安装使用及部署Web项目的3种方法汇总
2022/08/14 Servers