pytorch 实现在测试的时候启用dropout


Posted in Python onMay 27, 2021

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

pytorch 实现在测试的时候启用dropout

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

二.搭建神经网络

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

三.训练

pytorch 实现在测试的时候启用dropout

四.对比测试结果

注意:测试过程中,一定要注意模式切换

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现倒计时的示例
Feb 14 Python
Python常用内置函数总结
Feb 08 Python
python对指定目录下文件进行批量重命名的方法
Apr 18 Python
Python正则表达式使用经典实例
Jun 21 Python
基于asyncio 异步协程框架实现收集B站直播弹幕
Sep 11 Python
Python实现Logger打印功能的方法详解
Sep 01 Python
PyQt5 对图片进行缩放的实例
Jun 18 Python
Python 脚本拉取 Docker 镜像问题
Nov 10 Python
通过代码实例了解Python异常本质
Sep 16 Python
Python+OpenCV图像处理——实现轮廓发现
Oct 23 Python
pycharm debug 断点调试心得分享
Apr 16 Python
Python中np.random.randint()参数详解及用法实例
Sep 23 Python
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
tensorflow中的数据类型dtype用法说明
May 26 #Python
详解Python魔法方法之描述符类
May 26 #Python
You might like
phpmyadmin 常用选项设置详解版
2010/03/07 PHP
关于php curl获取301或302转向的网址问题的解决方法
2011/06/02 PHP
PHP入门经历和学习过程分享
2014/04/11 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(一)
2014/06/23 PHP
PHP下通过QRCode类库创建中间带网站LOGO的二维码
2014/07/12 PHP
php中__destruct与register_shutdown_function执行的先后顺序问题
2014/10/17 PHP
php 三大特点:封装,继承,多态
2017/02/19 PHP
PHP经典实用正则表达式小结
2017/05/04 PHP
脚本安需导入(装载)的三种模式的对比
2007/06/24 Javascript
重载toString实现JS HashMap分析
2011/03/13 Javascript
javascript实现在网页任意处点左键弹出隐藏菜单的方法
2015/05/13 Javascript
Web前端新人笔记之jquery入门心得(新手必看)
2016/05/17 Javascript
javascript的函数劫持浅析
2016/09/26 Javascript
flag和jq on 的绑定多个对象和方法(必看)
2017/02/27 Javascript
js实现下一页页码效果
2017/03/07 Javascript
vue elementUI 表单校验功能之数组多层嵌套
2019/06/04 Javascript
Layui Form 自定义验证的实例代码
2019/09/14 Javascript
countup.js实现数字动态叠加效果
2019/10/17 Javascript
python动态性强类型用法实例
2015/05/09 Python
关于Python作用域自学总结
2019/06/10 Python
Python学习笔记之While循环用法分析
2019/08/14 Python
Python Django框架防御CSRF攻击的方法分析
2019/10/18 Python
Python matplotlib以日期为x轴作图代码实例
2019/11/22 Python
tensorflow 报错unitialized value的解决方法
2020/02/06 Python
Django 404、500页面全局配置知识点详解
2020/03/10 Python
Python参数传递机制传值和传引用原理详解
2020/05/22 Python
python报错: 'list' object has no attribute 'shape'的解决
2020/07/15 Python
详解如何在css中引入自定义字体(font-face)
2018/05/17 HTML / CSS
美国美妆网站:B-Glowing
2016/10/12 全球购物
COACH德国官方网站:纽约现代奢侈品牌,1941年
2018/06/09 全球购物
工程业务员工作职责
2013/12/07 职场文书
英文自荐信常用句子
2014/03/26 职场文书
2015年人事专员工作总结
2015/04/29 职场文书
停发工资证明范本
2015/06/12 职场文书
详解Python如何批量采集京东商品数据流程
2022/01/22 Python
MySQL常用慢查询分析工具详解
2022/08/14 MySQL