Keras模型转成tensorflow的.pb操作


Posted in Python onJuly 06, 2020

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

以上这篇Keras模型转成tensorflow的.pb操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现数据导出excel格式的简单方法(推荐)
Dec 30 Python
关于Python面向对象编程的知识点总结
Feb 14 Python
Python学习小技巧之列表项的排序
May 20 Python
Django中利用filter与simple_tag为前端自定义函数的实现方法
Jun 15 Python
Python实现针对含中文字符串的截取功能示例
Sep 22 Python
Python多继承原理与用法示例
Aug 23 Python
python调用百度语音识别api
Aug 30 Python
ubuntu16.04制作vim和python3的开发环境
Sep 23 Python
python关于调用函数外的变量实例
Dec 26 Python
Python openpyxl模块实现excel读写操作
Jun 30 Python
Python使用for生成列表实现过程解析
Sep 22 Python
python 对图片进行简单的处理
Jun 23 Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
如何验证python安装成功
Jul 06 #Python
使用Keras训练好的.h5模型来测试一个实例
Jul 06 #Python
You might like
php中常用编辑器推荐
2007/01/02 PHP
攻克CakePHP系列三 表单数据增删改
2008/10/22 PHP
php封装的连接Mysql类及用法分析
2015/12/10 PHP
PHP二维数组排序简单实现方法
2016/02/14 PHP
PHP isset()与empty()的使用区别详解
2017/02/10 PHP
利用404错误页面实现UrlRewrite的实现代码
2008/08/20 Javascript
JavaScript中的面向对象介绍
2012/06/30 Javascript
form表单中去掉默认的enter键提交并绑定js方法实现代码
2013/04/01 Javascript
jQuery实现固定在网页顶部的菜单效果代码
2015/09/02 Javascript
ng-alain表单使用方式详解
2018/07/10 Javascript
如何理解Vue的v-model指令的使用方法
2018/07/19 Javascript
浅谈Angular7 项目开发总结
2018/12/19 Javascript
9102了,你还不会移动端真机调试吗
2019/03/25 Javascript
layer.js open 隐藏滚动条的例子
2019/09/05 Javascript
微信小程序实现时间进度条功能
2020/11/17 Javascript
python生成IP段的方法
2015/07/07 Python
Python实现基于TCP UDP协议的IPv4 IPv6模式客户端和服务端功能示例
2018/03/22 Python
PyCharm代码整体缩进,反向缩进的方法
2018/06/25 Python
python 解压pkl文件的方法
2018/10/25 Python
Tensorflow分类器项目自定义数据读入的实现
2019/02/05 Python
python按照多个条件排序的方法
2019/02/08 Python
python中的反斜杠问题深入讲解
2019/08/12 Python
python中的selenium安装的步骤(浏览器自动化测试框架)
2020/03/17 Python
PyTorch中torch.tensor与torch.Tensor的区别详解
2020/05/18 Python
Python pysnmp使用方法及代码实例
2020/08/24 Python
Python 数据分析之逐块读取文本的实现
2020/12/14 Python
详解CSS3中字体平滑处理和抗锯齿渲染
2017/03/29 HTML / CSS
澳大利亚领先的在线美容商店:Facial Co
2017/10/22 全球购物
澳大利亚购买太阳镜和眼镜网站:Glamoureyes
2020/09/22 全球购物
德国便宜的宠物店:Brekz.de
2020/10/23 全球购物
北京SQL新华信咨询
2016/09/30 面试题
关于礼仪的演讲稿
2014/01/04 职场文书
好学生评语大全
2014/05/05 职场文书
文明寝室标语
2014/06/13 职场文书
自主招生学校推荐信
2014/09/26 职场文书
2016创先争优活动党员公开承诺书
2016/03/24 职场文书