Keras模型转成tensorflow的.pb操作


Posted in Python onJuly 06, 2020

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

以上这篇Keras模型转成tensorflow的.pb操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python正则表达式匹配HTML页面编码
Apr 08 Python
通过源码分析Python中的切片赋值
May 08 Python
在java中如何定义一个抽象属性示例详解
Aug 18 Python
导入tensorflow时报错:cannot import name 'abs'的解决
Oct 10 Python
python爬虫爬取监控教务系统的思路详解
Jan 08 Python
浅谈ROC曲线的最佳阈值如何选取
Feb 28 Python
python随机模块random的22种函数(小结)
May 15 Python
Python爬虫设置ip代理过程解析
Jul 20 Python
Python过滤序列元素的方法
Jul 31 Python
python3通过subprocess模块调用脚本并和脚本交互的操作
Dec 05 Python
matplotlib绘制正余弦曲线图的实现
Feb 22 Python
Python进行区间取值案例讲解
Aug 02 Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
如何验证python安装成功
Jul 06 #Python
使用Keras训练好的.h5模型来测试一个实例
Jul 06 #Python
You might like
php下用cookie统计用户访问网页次数的代码
2010/05/09 PHP
PHP 二维数组和三维数组的过滤
2016/03/16 PHP
php array_walk_recursive 使用自定的函数处理数组中的每一个元素
2016/11/16 PHP
laravel 5.3 单用户登录简单实现方法
2019/10/14 PHP
window.ActiveXObject使用说明
2010/11/08 Javascript
javascript模拟实现ajax加载框实例
2014/10/15 Javascript
JavaScript File API文件上传预览
2016/02/02 Javascript
jQuery插件FusionCharts绘制的2D双面积图效果示例【附demo源码】
2017/04/11 jQuery
H5基于iScroll实现下拉刷新和上拉加载更多
2017/07/18 Javascript
es6学习之解构时应该注意的点
2017/08/29 Javascript
解决iview打包时UglifyJs报错的问题
2018/03/07 Javascript
JS实现的全选、全不选及反选功能【案例】
2019/02/19 Javascript
Vue中的组件及路由使用实例代码详解
2019/05/22 Javascript
在vue项目实现一个ctrl+f的搜索功能
2020/02/28 Javascript
JS对象属性的检测与获取操作实例分析
2020/03/17 Javascript
浅谈vue 组件中的setInterval方法和window的不同
2020/07/30 Javascript
解决vue项目中遇到 Cannot find module ‘chalk‘ 报错的问题
2020/11/05 Javascript
Django与遗留的数据库整合的方法指南
2015/07/24 Python
Python用zip函数同时遍历多个迭代器示例详解
2016/11/14 Python
Python基础学习之常见的内建函数整理
2017/09/06 Python
python+opencv实现动态物体追踪
2018/01/09 Python
python 处理dataframe中的时间字段方法
2018/04/10 Python
python和opencv实现抠图
2018/07/18 Python
Python读取excel中的图片完美解决方法
2018/07/27 Python
详解python分布式进程
2018/10/08 Python
python Tkinter版学生管理系统
2019/02/20 Python
【python】matplotlib动态显示详解
2019/04/11 Python
python 函数中的内置函数及用法详解
2019/07/02 Python
python实现WebSocket服务端过程解析
2019/10/18 Python
Django执行源生mysql语句实现过程解析
2020/11/12 Python
Yahoo-PHP面试题1
2016/07/20 面试题
简历上的自我评价
2014/02/03 职场文书
创建市级文明单位实施方案
2014/03/01 职场文书
项目申请汇报材料
2014/08/16 职场文书
几款流行的HTML5 UI框架比较(小结)
2021/04/08 HTML / CSS
SpringCloud的JPA连接PostgreSql的教程
2021/06/26 Java/Android