Keras模型转成tensorflow的.pb操作


Posted in Python onJuly 06, 2020

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

以上这篇Keras模型转成tensorflow的.pb操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现全局变量的两个解决方法
Jul 03 Python
Python命令行参数解析模块optparse使用实例
Apr 13 Python
解决python爬虫中有中文的url问题
May 11 Python
python bmp转换为jpg 并删除原图的方法
Oct 25 Python
Python当中的array数组对象实例详解
Jun 12 Python
python实现身份证实名认证的方法实例
Nov 08 Python
python使用协程实现并发操作的方法详解
Dec 27 Python
Python tkinter常用操作代码实例
Jan 03 Python
python3排序的实例方法
Oct 20 Python
python+appium+yaml移动端自动化测试框架实现详解
Nov 24 Python
python中Tkinter 窗口之输入框和文本框的实现
Apr 12 Python
解决jupyter notebook图片显示模糊和保存清晰图片的操作
Apr 24 Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
如何验证python安装成功
Jul 06 #Python
使用Keras训练好的.h5模型来测试一个实例
Jul 06 #Python
You might like
了解咖啡雨林联盟认证 什么是雨林认证 雨林认证是什么意思
2021/03/05 新手入门
模拟OICQ的实现思路和核心程序(三)
2006/10/09 PHP
Zend Framework框架中实现Ajax的方法示例
2017/06/27 PHP
PHP流Streams、包装器wrapper概念与用法实例详解
2017/11/17 PHP
微信公众号开发之获取位置信息php代码
2018/06/13 PHP
Laravel访问出错提示:`Warning: require(/vendor/autoload.php): failed to open stream: No such file or di解决方法
2019/04/02 PHP
js 处理URL实用技巧
2010/11/23 Javascript
Javascript继承(上)——对象构建介绍
2012/11/08 Javascript
Backbone.js的Hello World程序实例
2015/06/19 Javascript
animate 实现滑动切换效果【实例代码】
2016/05/05 Javascript
url中的特殊符号有什么含义(推荐)
2016/06/17 Javascript
AngularJs Javascript MVC 框架
2016/06/20 Javascript
Js操作DOM元素及获取浏览器高宽的简单方法
2016/09/08 Javascript
原生js封装自定义滚动条
2017/03/24 Javascript
中级前端工程师必须要掌握的27个JavaScript 技巧(干货总结)
2019/09/23 Javascript
JavaScript或jQuery 获取option value值方法解析
2020/05/12 jQuery
Vue检测屏幕变化来改变不同的charts样式实例
2020/10/26 Javascript
Python写的PHPMyAdmin暴力破解工具代码
2014/08/06 Python
Python实现的飞速中文网小说下载脚本
2015/04/23 Python
python实现查找两个字符串中相同字符并输出的方法
2015/07/11 Python
不管你的Python报什么错,用这个模块就能正常运行
2018/09/14 Python
python判断完全平方数的方法
2018/11/13 Python
win10子系统python开发环境准备及kenlm和nltk的使用教程
2019/10/14 Python
Python3.6 中的pyinstaller安装和使用教程
2020/03/16 Python
Python socket连接中的粘包、精确传输问题实例分析
2020/03/24 Python
Python实现多线程下载脚本的示例代码
2020/04/03 Python
30行Python代码实现高分辨率图像导航的方法
2020/05/22 Python
Python如何实现大型数组运算(使用NumPy)
2020/07/24 Python
HTML5标签嵌套规则详解【必看】
2016/04/26 HTML / CSS
草莓网化妆品加拿大网站:Strawberrynet Canada
2016/09/20 全球购物
《鹬蚌相争》教学反思
2014/04/22 职场文书
农村党支部书记四风问题个人对照检查材料
2014/09/21 职场文书
2014年租房协议书范本
2014/10/30 职场文书
远程教育培训心得体会
2016/01/09 职场文书
一篇文章弄懂MySQL查询语句的执行过程
2021/05/07 MySQL
MySQL中datetime时间字段的四舍五入操作
2021/10/05 MySQL