Keras模型转成tensorflow的.pb操作


Posted in Python onJuly 06, 2020

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

以上这篇Keras模型转成tensorflow的.pb操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python第三方库的安装方法总结
Jun 06 Python
Python实现向服务器请求压缩数据及解压缩数据的方法示例
Jun 09 Python
django自带的server 让外网主机访问方法
May 14 Python
浅谈python下tiff图像的读取和保存方法
Dec 04 Python
Python 运行.py文件和交互式运行代码的区别详解
Jul 02 Python
python 批量修改 labelImg 生成的xml文件的方法
Sep 09 Python
python进程间通信Queue工作过程详解
Nov 01 Python
Python 装饰器原理、定义与用法详解
Dec 07 Python
Python函数的定义方式与函数参数问题实例分析
Dec 26 Python
基于python实现ROC曲线绘制广场解析
Jun 28 Python
Tensorflow tensor 数学运算和逻辑运算方式
Jun 30 Python
python实现简单聊天功能
Jul 07 Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
如何验证python安装成功
Jul 06 #Python
使用Keras训练好的.h5模型来测试一个实例
Jul 06 #Python
You might like
laravel5使用freetds连接sql server的方法
2018/12/07 PHP
Laravel5.3+框架定义API路径取消CSRF保护方法详解
2020/04/06 PHP
在JavaScript中实现命名空间
2006/11/23 Javascript
用jquery来定位
2007/02/20 Javascript
Div自动滚动到末尾的代码
2008/10/26 Javascript
window.print打印指定div指定网页指定区域的方法
2014/08/04 Javascript
浅谈EasyUI中编辑treegrid的方法
2015/03/01 Javascript
js控件Kindeditor实现图片自动上传功能
2020/07/20 Javascript
JS JSOP跨域请求实例详解
2016/07/04 Javascript
JavaScript中误用/g导致的正则test()无法正确重复执行的解决方案
2016/07/27 Javascript
BootStrap实现邮件列表的分页和模态框添加邮件的功能
2016/10/13 Javascript
javaScript中定义类或对象的五种方式总结
2016/12/04 Javascript
移动端效果之Swiper详解
2017/10/09 Javascript
vue.js整合mint-ui里的轮播图实例代码
2017/12/27 Javascript
jQuery获取随机颜色的实例代码
2018/05/21 jQuery
微信小程序实现红包功能(后端PHP实现逻辑)
2018/07/11 Javascript
详解JavaScript事件循环机制
2018/09/07 Javascript
vue-cli 使用vue-bus来全局控制的实例讲解
2018/09/15 Javascript
AngularJS $http post 传递参数数据的方法
2018/10/09 Javascript
基于JQuery实现页面定时弹出广告
2020/05/08 jQuery
vue vant中picker组件的使用
2020/11/03 Javascript
[24:42]VP vs TNC Supermajor小组赛B组 BO3 第三场 6.2
2018/06/03 DOTA
Python实现的使用telnet登陆聊天室实例
2015/06/17 Python
分享Python开发中要注意的十个小贴士
2016/08/30 Python
Python中字典的setdefault()方法教程
2017/02/07 Python
Python学习笔记之Break和Continue用法分析
2019/08/14 Python
TensorBoard 计算图的查看方式
2020/02/15 Python
keras分类之二分类实例(Cat and dog)
2020/07/09 Python
Crucial英睿达法国官网:内存条及SSD固态硬盘升级
2018/07/13 全球购物
美国隐形眼镜零售商:LensPure
2019/03/10 全球购物
StubHub中国:购买和出售全球活动门票
2020/01/01 全球购物
俄罗斯连接商品和买家的在线平台:goods.ru
2020/11/30 全球购物
The North Face意大利官网:服装、背包和鞋子
2020/06/17 全球购物
教师读书活动总结
2014/05/07 职场文书
网络文明传播志愿者活动方案
2014/08/20 职场文书
vue里使用create, mounted调用方法
2022/04/26 Vue.js