浅谈keras中Dropout在预测过程中是否仍要起作用


Posted in Python onJuly 09, 2020

因为需要,要重写训练好的keras模型,虽然只具备预测功能,但是发现还是有很多坑要趟过。其中Dropout这个坑,我记忆犹新。

一开始,我以为预测时要保持和训练时完全一样的网络结构,也就是预测时用的网络也是有丢弃的网络节点,但是这样想就掉进了一个大坑!因为无法通过已经训练好的模型,来获取其训练时随机丢弃的网络节点是那些,这本身就根本不可能。

更重要的是:我发现每一个迭代周期丢弃的神经元也不完全一样。

假若迭代500次,网络共有1000个神经元, 在第n(1<= n <500)个迭代周期内,从1000个神经元里随机丢弃了200个神经元,在n+1个迭代周期内,会在这1000个神经元里(不是在剩余得800个)重新随机丢弃200个神经元。

训练过程中,使用Dropout,其实就是对部分权重和偏置在某次迭代训练过程中,不参与计算和更新而已,并不是不再使用这些权重和偏置了(预测时,会使用全部的神经元,包括使用训练时丢弃的神经元)。

也就是说在预测过程中完全没有Dropout什么事了,他只是在训练时有用,特别是针对训练集比较小时防止过拟合非常有用。

补充知识:TensorFlow直接使用ckpt模型predict不用restore

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
# from util import *
import cv2
import numpy as np
import tensorflow as tf
# from tensorflow.python.framework import graph_util
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
image_path = './8760.pgm'

input_checkpoint = './model/xu_spatial_model_1340.ckpt'

sess = tf.Session()
saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
saver.restore(sess, input_checkpoint)

# input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
input_image_tensor = sess.graph.get_tensor_by_name("coef_input:0")
is_training = sess.graph.get_tensor_by_name('is_training:0')
batch_size = sess.graph.get_tensor_by_name('batch_size:0')
# 定义输出的张量名称
output_tensor_name = sess.graph.get_tensor_by_name("xuNet/logits:0") # xuNet/Logits/logits
image = cv2.imread(image_path, 0)
# 读取测试图片
out = sess.run(output_tensor_name, feed_dict={input_image_tensor: np.reshape(image, (1, 512, 512, 1)),
                       is_training: False,
                       batch_size: 1})
print(out)

ckpt模型中的所有节点名称,可以这样查看

[n.name for n in tf.get_default_graph().as_graph_def().node]

以上这篇浅谈keras中Dropout在预测过程中是否仍要起作用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Web框架Flask中使用新浪SAE云存储实例
Feb 08 Python
在Python的Django框架中更新数据库数据的方法
Jul 17 Python
对Pycharm创建py文件时自定义头部模板的方法详解
Feb 12 Python
Pythony运维入门之Socket网络编程详解
Apr 15 Python
python多进程读图提取特征存npy
May 21 Python
在python里面运用多继承方法详解
Jul 01 Python
Python学习笔记之Django创建第一个数据库模型的方法
Aug 07 Python
使用turtle绘制五角星、分形树
Oct 06 Python
python中with语句结合上下文管理器操作详解
Dec 19 Python
Python 虚拟环境工作原理解析
Dec 24 Python
python中subplot大小的设置步骤
Jun 28 Python
Python函数中apply、map、applymap的区别
Nov 27 Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
Django中Aggregation聚合的基本使用方法
Jul 09 #Python
Python  word实现读取及导出代码解析
Jul 09 #Python
推荐技术人员一款Python开源库(造数据神器)
Jul 08 #Python
实例讲解Python 迭代器与生成器
Jul 08 #Python
opencv 阈值分割的具体使用
Jul 08 #Python
You might like
使用php+xslt在windows平台上
2006/10/09 PHP
php max_execution_time执行时间问题
2011/07/17 PHP
php cookie名使用点号(句号)会被转换
2014/10/23 PHP
Laravel接收前端ajax传来的数据的实例代码
2017/07/20 PHP
PHP PDOStatement::debugDumpParams讲解
2019/01/30 PHP
JQuery 构建客户/服务分离的链接模型中Table分页代码效率初探
2010/01/22 Javascript
SOSO地图API使用(一)在地图上画圆实现思路与代码
2013/01/15 Javascript
JS+CSS实现下拉列表框美化效果(3款)
2015/08/15 Javascript
Angular发布1.5正式版,专注于向Angular 2的过渡
2016/02/18 Javascript
Javascript+CSS3实现进度条效果
2016/10/28 Javascript
Javascript 跨域知识详细介绍
2016/10/30 Javascript
bootstrap datetimepicker实现秒钟选择下拉框
2017/01/05 Javascript
JavaScript实现星级评分
2017/01/12 Javascript
微信小程序 devtool隐藏的秘密
2017/01/21 Javascript
js canvas实现简单的图像扩散效果
2020/06/28 Javascript
Angular之toDoList的实现代码示例
2017/12/02 Javascript
微信小程序云开发之使用云数据库
2019/05/17 Javascript
js模拟实现百度搜索
2020/06/28 Javascript
python中django框架通过正则搜索页面上email地址的方法
2015/03/21 Python
关于python列表增加元素的三种操作方法
2018/08/22 Python
python获取全国城市pm2.5、臭氧等空气质量过程解析
2019/10/12 Python
Python通过正则库爬取淘宝商品信息代码实例
2020/03/02 Python
Python3 操作 MySQL 插入一条数据并返回主键 id的实例
2020/03/02 Python
Python使用lambda抛出异常实现方法解析
2020/08/20 Python
Django日志及中间件模块应用案例
2020/09/10 Python
python爬虫爬取某网站视频的示例代码
2021/02/20 Python
初中科学教学反思
2014/01/21 职场文书
仓库管理制度
2014/01/21 职场文书
索桥的故事教学反思
2014/02/06 职场文书
劲霸男装广告词
2014/03/21 职场文书
物流专业自荐信
2014/05/23 职场文书
后备干部推荐材料
2014/12/24 职场文书
求职简历自我评价2015
2015/03/10 职场文书
2015年社区妇联工作总结
2015/04/21 职场文书
python 经纬度求两点距离、三点面积操作
2021/06/03 Python
Shell中的单中括号和双中括号的用法详解
2022/12/24 Servers