浅谈keras中Dropout在预测过程中是否仍要起作用


Posted in Python onJuly 09, 2020

因为需要,要重写训练好的keras模型,虽然只具备预测功能,但是发现还是有很多坑要趟过。其中Dropout这个坑,我记忆犹新。

一开始,我以为预测时要保持和训练时完全一样的网络结构,也就是预测时用的网络也是有丢弃的网络节点,但是这样想就掉进了一个大坑!因为无法通过已经训练好的模型,来获取其训练时随机丢弃的网络节点是那些,这本身就根本不可能。

更重要的是:我发现每一个迭代周期丢弃的神经元也不完全一样。

假若迭代500次,网络共有1000个神经元, 在第n(1<= n <500)个迭代周期内,从1000个神经元里随机丢弃了200个神经元,在n+1个迭代周期内,会在这1000个神经元里(不是在剩余得800个)重新随机丢弃200个神经元。

训练过程中,使用Dropout,其实就是对部分权重和偏置在某次迭代训练过程中,不参与计算和更新而已,并不是不再使用这些权重和偏置了(预测时,会使用全部的神经元,包括使用训练时丢弃的神经元)。

也就是说在预测过程中完全没有Dropout什么事了,他只是在训练时有用,特别是针对训练集比较小时防止过拟合非常有用。

补充知识:TensorFlow直接使用ckpt模型predict不用restore

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
# from util import *
import cv2
import numpy as np
import tensorflow as tf
# from tensorflow.python.framework import graph_util
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
image_path = './8760.pgm'

input_checkpoint = './model/xu_spatial_model_1340.ckpt'

sess = tf.Session()
saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
saver.restore(sess, input_checkpoint)

# input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
input_image_tensor = sess.graph.get_tensor_by_name("coef_input:0")
is_training = sess.graph.get_tensor_by_name('is_training:0')
batch_size = sess.graph.get_tensor_by_name('batch_size:0')
# 定义输出的张量名称
output_tensor_name = sess.graph.get_tensor_by_name("xuNet/logits:0") # xuNet/Logits/logits
image = cv2.imread(image_path, 0)
# 读取测试图片
out = sess.run(output_tensor_name, feed_dict={input_image_tensor: np.reshape(image, (1, 512, 512, 1)),
                       is_training: False,
                       batch_size: 1})
print(out)

ckpt模型中的所有节点名称,可以这样查看

[n.name for n in tf.get_default_graph().as_graph_def().node]

以上这篇浅谈keras中Dropout在预测过程中是否仍要起作用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 过滤字符串的技巧,map与itertools.imap
Sep 06 Python
Python开发实例分享bt种子爬虫程序和种子解析
May 21 Python
Python入门篇之列表和元组
Oct 17 Python
Python MySQLdb Linux下安装笔记
May 09 Python
Python中random模块生成随机数详解
Mar 10 Python
python logging日志模块的详解
Oct 29 Python
django之session与分页(实例讲解)
Nov 13 Python
Python之多线程爬虫抓取网页图片的示例代码
Jan 10 Python
对dataframe进行列相加,行相加的实例
Jun 08 Python
Python hashlib加密模块常用方法解析
Dec 18 Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 Python
Python多线程Threading、子线程与守护线程实例详解
Mar 24 Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
Django中Aggregation聚合的基本使用方法
Jul 09 #Python
Python  word实现读取及导出代码解析
Jul 09 #Python
推荐技术人员一款Python开源库(造数据神器)
Jul 08 #Python
实例讲解Python 迭代器与生成器
Jul 08 #Python
opencv 阈值分割的具体使用
Jul 08 #Python
You might like
详解PHP中的Traits
2015/07/29 PHP
php中switch语句用法详解
2015/08/17 PHP
jQuery+css实现图片滚动效果(附源码)
2013/03/18 Javascript
用js格式化金额可设置保留的小数位数
2014/05/09 Javascript
js 获取浏览器版本以此来调整CSS的样式
2014/06/03 Javascript
使用jQuery实现图片遮罩半透明坠落遮挡
2015/03/16 Javascript
jQuery中trigger()与bind()用法分析
2015/12/18 Javascript
javascript计算渐变颜色的实例
2017/09/22 Javascript
vue.js轮播图组件使用方法详解
2018/07/03 Javascript
vue环形进度条组件实例应用
2018/10/10 Javascript
Vue2 添加数据可视化支持的方法步骤
2019/01/02 Javascript
jQuery实现input[type=file]多图预览上传删除等功能
2019/08/02 jQuery
简单了解微信小程序 e.target与e.currentTarget的不同
2019/09/27 Javascript
React倒计时功能实现代码——解耦通用
2020/09/18 Javascript
[02:18]《我与DAC》之工作人员:为了热爱DOTA2的玩家们
2018/03/28 DOTA
Python3爬虫学习入门教程
2018/12/11 Python
pandas分组聚合详解
2020/04/10 Python
python对接ihuyi实现短信验证码发送
2020/05/10 Python
win7上tensorflow2.2.0安装成功 引用DLL load failed时找不到指定模块 tensorflow has no attribute xxx 解决方法
2020/05/20 Python
html5 input输入实时检测以及延时优化
2018/07/18 HTML / CSS
世界首屈一指的钓鱼用品商店:TackleDirect
2016/07/26 全球购物
adidas官方旗舰店:德国运动用品制造商
2017/11/25 全球购物
泰国网上购物:Shopee泰国
2018/09/14 全球购物
澳大利亚网上买书:Angus & Robertson
2019/07/21 全球购物
美国领先的眼镜和太阳镜在线零售商:Glasses.com
2019/08/26 全球购物
贯彻学习两会心得体会范文
2014/03/17 职场文书
劲霸男装广告词改编版
2014/03/21 职场文书
升学宴主持词
2014/04/02 职场文书
公路绿化方案
2014/05/12 职场文书
小区文明倡议书
2014/05/16 职场文书
运动会铅球比赛加油稿
2014/09/26 职场文书
预备党员群众路线思想汇报2014
2014/10/25 职场文书
先进集体申报材料
2014/12/25 职场文书
2019年聘任书的写作格式及范文!
2019/07/03 职场文书
教你怎么用Python生成九宫格照片
2021/05/20 Python
Oracle 死锁的检测查询及处理
2021/09/25 Oracle