使用Keras训练好的.h5模型来测试一个实例


Posted in Python onJuly 06, 2020

环境:python 3.6 +opencv3+Keras

训练集:MNIST

下面划重点:因为MNIST使用的是黑底白字的图片,所以你自己手写数字的时候一定要注意把得到的图片也改成黑底白字的,否则会识别错(至少我得到的结论是这样的 ,之前用白底黑字的图总是识别出错)

注意:需要测试图片需要为与训练模时相同大小的图片,RGB图像需转为gray

代码:

import cv2
import numpy as np
from keras.models import load_model

model = load_model('fm_cnn_BN.h5') #选取自己的.h模型名称
image = cv2.imread('6_b.png')
img = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY) # RGB图像转为gray

#需要用reshape定义出例子的个数,图片的 通道数,图片的长与宽。具体的参加keras文档
img = (img.reshape(1, 1, 28, 28)).astype('int32')/255 
predict = model.predict_classes(img)
print ('识别为:')
print (predict)

cv2.imshow("Image1", image)
cv2.waitKey(0)

补充知识:keras转tf并加速(1)Keras转TensorFlow,并调用转换后模型进行预测

由于方便快捷,所以先使用Keras来搭建网络并进行训练,得到比较好的模型后,这时候就该考虑做成服务使用的问题了,TensorFlow的serving就很合适,所以需要把Keras保存的模型转为TensorFlow格式来使用。

Keras模型转TensorFlow

其实由于TensorFlow本身以及把Keras作为其高层简化API,且也是建议由浅入深地来研究应用,TensorFlow本身就对Keras的模型格式转化有支持,所以核心的代码很少。这里给出一份代码:https://github.com/amir-abdi/keras_to_tensorflow,作者提供了一份很好的工具,能够满足绝大多数人的需求了。原理很简单:原理很简单,首先用 Keras 读取 .h5 模型文件,然后用 tensorflow 的 convert_variables_to_constants 函数将所有变量转换成常量,最后再 write_graph 就是一个包含了网络以及参数值的 .pb 文件了。

如果你的Keras模型是一个包含了网络结构和权重的h5文件,那么使用下面的命令就可以了:

python keras_to_tensorflow.py 
 --input_model="path/to/keras/model.h5" 
 --output_model="path/to/save/model.pb"

两个参数,一个输入路径,一个输出路径。输出路径即使你没创建好,代码也会帮你创建。建议使用绝对地址。此外作者还做了很多选项,比如如果你的keras模型文件分为网络结构和权重两个文件也可以支持,或者你想给转化后的网络节点编号,或者想在TensorFlow下继续训练等等,这份代码都是支持的,只是使用上需要输入不同的参数来设置。

如果转换成功则输出如下:

begin====================================================
I1229 14:29:44.819010 140709034264384 keras_to_tf.py:119] Input nodes names are: [u'input_1']
I1229 14:29:44.819385 140709034264384 keras_to_tf.py:137] Converted output node names are: [u'dense_2/Sigmoid']
INFO:tensorflow:Froze 322 variables.
I1229 14:29:47.091161 140709034264384 tf_logging.py:82] Froze 322 variables.
Converted 322 variables to const ops.
I1229 14:29:48.504235 140709034264384 keras_to_tf.py:170] Saved the freezed graph at /path/to/save/model.pb

这里首先把输入的层和输出的层名字给出来了,也就是“input_1”和“dense_2/Sigmoid”,这两个下面会用到。另外还告诉你冻结了多少个变量,以及你输出的模型路径,pb文件就是TensorFlow下的模型文件。

使用TensorFlow模型

转换后我们当然要使用一下看是否转换成功,其实也就是TensorFlow的常见代码,如果只用过Keras的,可以参考一下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
import cv2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
 
# img = cv2.imread(os.path.expanduser('/test_imgs/img_1.png'))
# img = cv2.resize(img, dsize=(1000, 1000), interpolation=cv2.INTER_LINEAR)
# img = img.astype(float)
# img /= 255
# img = np.array([img])
 
# 初始化TensorFlow的session
with tf.Session() as sess:
 # 读取得到的pb文件加载模型
 with gfile.FastGFile("/path/to/save/model.pb",'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 # 把图加到session中
 tf.import_graph_def(graph_def, name='')
 
 # 获取当前计算图
 graph = tf.get_default_graph()
 
 # 从图中获输出那一层
 pred = graph.get_tensor_by_name("dense_2/Sigmoid:0")
 
 # 运行并预测输入的img
 res = sess.run(pred, feed_dict={"input_1:0": img})
 
 # 执行得到结果
 pred_index = res[0][0]
 print('Predict:', pred_index)

在代码中可以看到,我们用到了上面得到的输入层和输出层的名称,但是在后面加了一个“:0”,也就是索引,因为名称只是指定了一个层,大部分层的输出都是一个tensor,但依然有输出多个tensor的层,所以需要制定是第几个输出,对于一个输出的情况,那就是索引0了。输入同理。

如果你输出res,会得到这样的结果:

('Predict:', array([[0.9998584]], dtype=float32))

这也就是为什么我们要取res[0][0]了,这个输出其实取决于具体的需求,因为这里我是对一张图做二分类预测,所以会得到这样一个结果

运行的结果如果和使用Keras模型时一样,那就说明转换成功了!

以上这篇使用Keras训练好的.h5模型来测试一个实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Tornado框架异步编程入门实例
Apr 24 Python
python实现搜索本地文件信息写入文件的方法
Feb 22 Python
使用python3.5仿微软记事本notepad
Jun 15 Python
Python中%r和%s的详解及区别
Mar 16 Python
Python读取和处理文件后缀为.sqlite的数据文件(实例讲解)
Jun 27 Python
python绘制热力图heatmap
Mar 23 Python
python用列表生成式写嵌套循环的方法
Nov 08 Python
解决pyinstaller打包pyqt5的问题
Jan 08 Python
Python+PyQt5实现美剧爬虫可视工具的方法
Apr 25 Python
将tensorflow.Variable中的某些元素取出组成一个新的矩阵示例
Jan 04 Python
python使用Thread的setDaemon启动后台线程教程
Apr 25 Python
基于Python实现天天酷跑功能
Jan 06 Python
Keras实现DenseNet结构操作
Jul 06 #Python
基于Python和C++实现删除链表的节点
Jul 06 #Python
基于Python 的语音重采样函数解析
Jul 06 #Python
python interpolate插值实例
Jul 06 #Python
基于Python实现2种反转链表方法代码实例
Jul 06 #Python
简单了解Django项目应用创建过程
Jul 06 #Python
如何在mac下配置python虚拟环境
Jul 06 #Python
You might like
Smarty结合Ajax实现无刷新留言本实例
2007/01/02 PHP
PHP+.htaccess实现全站静态HTML文件GZIP压缩传输(一)
2007/02/15 PHP
PHP原生模板引擎 最简单的模板引擎
2012/04/25 PHP
PHP FTP操作类代码( 上传、拷贝、移动、删除文件/创建目录)
2014/05/10 PHP
PHP实现图片不变型裁剪及图片按比例裁剪的方法
2016/01/14 PHP
Laravel5中防止XSS跨站攻击的方法
2016/10/10 PHP
ThinkPHP 5 AJAX跨域请求头设置实现过程解析
2020/10/28 PHP
用Javascript评估用户输入密码的强度实现代码
2011/11/30 Javascript
JS定义回车事件(实现代码)
2013/07/08 Javascript
jquery按回车提交数据的代码示例
2013/11/05 Javascript
jQuery动态显示和隐藏datagrid中的某一列的方法
2013/12/11 Javascript
jquery实现的仿天猫侧导航tab切换效果
2015/08/24 Javascript
js实现简洁大方的二级下拉菜单效果代码
2015/09/01 Javascript
jquery遍历函数siblings()用法实例
2015/12/24 Javascript
JavaScript+html5 canvas绘制渐变区域完整实例
2016/01/26 Javascript
Bootstrap栅格系统的使用详解
2017/10/30 Javascript
基于jQuery实现无缝轮播与左右点击效果
2018/05/13 jQuery
vue实现菜单切换功能
2019/05/08 Javascript
使用npm命令提示: 'npm' 不是内部或外部命令,也不是可运行的程序的处理方法
2020/05/14 Javascript
vue单文件组件无法获取$refs的问题
2020/06/24 Javascript
vue路由的配置和页面切换详解
2020/09/09 Javascript
wxpython 学习笔记 第一天
2009/03/16 Python
在Python中使用成员运算符的示例
2015/05/13 Python
Python 爬虫的工具列表大全
2016/01/31 Python
pyQt5实时刷新界面的示例
2019/06/25 Python
Python修改DBF文件指定列
2020/12/19 Python
NIHAOMARKET官方海外旗舰店:意大利你好华人超市
2018/01/27 全球购物
意大利制造的西装、衬衫和针对男士量身定制的服装:Lanieri
2018/04/08 全球购物
报社实习生自荐信
2014/01/24 职场文书
《小石潭记》教学反思
2014/02/13 职场文书
篝火晚会主持词
2014/03/25 职场文书
销售岗位职责范本
2014/06/12 职场文书
2014国庆节商场促销活动策划方案
2014/09/16 职场文书
授权委托书
2014/09/17 职场文书
公司总经理岗位职责
2015/04/01 职场文书
2019年暑期安全广播稿!
2019/07/03 职场文书