Keras模型转成tensorflow的.pb操作


Posted in Python onJuly 06, 2020

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

以上这篇Keras模型转成tensorflow的.pb操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Python中is和id的用法
Apr 03 Python
python使用reportlab实现图片转换成pdf的方法
May 22 Python
基于python脚本实现软件的注册功能(机器码+注册码机制)
Oct 09 Python
Scrapy爬虫实例讲解_校花网
Oct 23 Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 Python
解决PyCharm不运行脚本,而是运行单元测试的问题
Jan 17 Python
python地震数据可视化详解
Jun 18 Python
Numpy的简单用法小结
Aug 28 Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 Python
Python Web项目Cherrypy使用方法镜像
Nov 05 Python
python 实现逻辑回归
Dec 30 Python
Python3自带工具2to3.py 转换 Python2.x 代码到Python3的操作
Mar 03 Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
如何验证python安装成功
Jul 06 #Python
使用Keras训练好的.h5模型来测试一个实例
Jul 06 #Python
You might like
php中使用preg_match_all匹配文章中的图片
2013/02/06 PHP
laravel 4安装及入门图文教程
2014/10/29 PHP
php中substr()函数参数说明及用法实例
2014/11/15 PHP
基于PHP实现假装商品限时抢购繁忙的效果
2015/10/16 PHP
php数据结构之顺序链表与链式线性表示例
2018/01/22 PHP
Yii框架安装简明教程
2020/05/15 PHP
jQuery中调用WebService方法小结
2011/03/28 Javascript
JQuery表格内容过滤的实现方法
2013/07/05 Javascript
JavaScript调用客户端的可执行文件(示例代码)
2013/11/28 Javascript
js控制浏览器全屏示例代码
2014/02/20 Javascript
使用upstart把nodejs应用封装为系统服务实例
2014/06/01 NodeJs
js中for in语句的用法讲解
2015/04/24 Javascript
DWR中各种java方法的调用
2016/05/04 Javascript
详解使用Vue.Js结合Jquery Ajax加载数据的两种方式
2017/01/10 Javascript
JS实现上传图片实时预览功能
2017/05/22 Javascript
vue-router3.0版本中 router.push 不能刷新页面的问题
2018/05/10 Javascript
vue-cli3访问public文件夹静态资源报错的解决方式
2020/09/02 Javascript
[43:58]DOTA2上海特级锦标赛C组败者赛 Newbee VS Archon第二局
2016/02/27 DOTA
基于Python实现的百度贴吧网络爬虫实例
2015/04/17 Python
[原创]教女朋友学Python(一)运行环境搭建
2017/11/29 Python
Python:Scrapy框架中Item Pipeline组件使用详解
2017/12/27 Python
Python爬虫实战:分析《战狼2》豆瓣影评
2018/03/26 Python
python 求1-100之间的奇数或者偶数之和的实例
2019/06/11 Python
关于numpy.where()函数 返回值的解释
2019/12/06 Python
Pyecharts绘制全球流向图的示例代码
2020/01/08 Python
Python如何操作office实现自动化及win32com.client的运用
2020/04/01 Python
Python selenium模拟手动操作实现无人值守刷积分功能
2020/05/13 Python
python 下划线的不同用法
2020/10/24 Python
一文带你掌握Pyecharts地理数据可视化的方法
2021/02/06 Python
超酷炫 CSS3垂直手风琴菜单
2016/06/28 HTML / CSS
iframe跨域的几种常用方法
2019/11/11 HTML / CSS
英国复古服装和球衣购买网站:3Retro Football
2018/07/09 全球购物
中学生旷课检讨书模板
2014/10/08 职场文书
2015年农村党员干部主题教育活动总结
2015/03/25 职场文书
关于PHP数组迭代器的使用方法实例
2021/11/17 PHP
python 管理系统实现mysql交互的示例代码
2021/12/06 Python