pytorch 实现在测试的时候启用dropout


Posted in Python onMay 27, 2021

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

pytorch 实现在测试的时候启用dropout

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

二.搭建神经网络

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

三.训练

pytorch 实现在测试的时候启用dropout

四.对比测试结果

注意:测试过程中,一定要注意模式切换

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
K-means聚类算法介绍与利用python实现的代码示例
Nov 13 Python
Scrapy框架CrawlSpiders的介绍以及使用详解
Nov 29 Python
Python+matplotlib绘制不同大小和颜色散点图实例
Jan 19 Python
Python爬虫实现验证码登录代码实例
May 10 Python
pandas实现将dataframe满足某一条件的值选出
Jun 12 Python
Python Django form 组件动态从数据库取choices数据实例
May 19 Python
python查看矩阵的行列号以及维数方式
May 22 Python
Python如何在循环内使用list.remove()
Jun 01 Python
python+selenium 简易地疫情信息自动打卡签到功能的实现代码
Aug 22 Python
Pytorch之Tensor和Numpy之间的转换的实现方法
Sep 03 Python
Pytorch实现WGAN用于动漫头像生成
Mar 04 Python
Python常遇到的错误和异常
Nov 02 Python
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
tensorflow中的数据类型dtype用法说明
May 26 #Python
详解Python魔法方法之描述符类
May 26 #Python
You might like
为什么夜间收到的中波电台比白天多
2021/03/01 无线电
xajax写的留言本
2006/11/25 PHP
一个php Mysql类 可以参考学习熟悉下
2009/06/21 PHP
php echo, print, print_r, sprintf, var_dump, var_expor的使用区别
2013/06/20 PHP
PHP中一些可以替代正则表达式函数的字符串操作函数
2014/11/17 PHP
php将HTML表格每行每列转为数组实现采集表格数据的方法
2015/04/03 PHP
利用PHP访问MySql数据库的逻辑操作以及增删改查的实例讲解
2017/08/30 PHP
JavaScript Cookie 直接浏览网站分网址
2009/12/08 Javascript
jquery处理json对象
2014/11/03 Javascript
JSON对象 详解及实例代码
2016/10/18 Javascript
bootstrap datetimepicker日期插件使用方法
2017/01/13 Javascript
vue中如何实现变量和字符串拼接
2017/06/19 Javascript
Bootstrap实现翻页效果
2017/11/27 Javascript
Bootstrap开发中Tab标签页切换图表显示问题的解决方法
2018/07/13 Javascript
简单了解node npm cnpm的具体使用方法
2019/02/27 Javascript
js回文数的4种判断方法示例
2019/06/04 Javascript
vue实现在线预览pdf文件和下载(pdf.js)
2019/11/26 Javascript
vue远程加载sfc组件思路详解
2019/12/25 Javascript
Vue 实现登录界面验证码功能
2020/01/03 Javascript
微信小程序仿抖音视频之整屏上下切换功能的实现代码
2020/05/24 Javascript
Vue切换div显示隐藏,多选,单选代码解析
2020/07/14 Javascript
浅谈JS for循环中使用break和continue的区别
2020/07/21 Javascript
详解Python的Django框架中的模版相关知识
2015/07/15 Python
python模块smtplib实现纯文本邮件发送功能
2018/05/22 Python
python脚本监控Tomcat服务器的方法
2018/07/06 Python
对python xlrd读取datetime类型数据的方法详解
2018/12/26 Python
Python使用pyserial进行串口通信的实例
2019/07/02 Python
Pymysql实现往表中插入数据过程解析
2020/06/02 Python
Python常用GUI框架原理解析汇总
2020/12/07 Python
css3 transform导致子元素固定定位变成绝对定位的方法
2020/03/06 HTML / CSS
师范院校学生自荐信范文
2013/12/27 职场文书
员工安全生产责任书
2014/07/22 职场文书
党员对照检查材料思想汇报
2014/09/16 职场文书
交通事故赔偿协议书怎么写
2014/10/04 职场文书
百年孤独读书笔记
2015/06/29 职场文书
Python PIL按比例裁剪图片
2022/05/11 Python