将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python yield使用方法示例
Dec 04 Python
Python中的类与对象之描述符详解
Mar 27 Python
Python+Wordpress制作小说站
Apr 14 Python
Python 装饰器实现DRY(不重复代码)原则
Mar 05 Python
pandas数值计算与排序方法
Apr 12 Python
对Python中gensim库word2vec的使用详解
May 08 Python
Python3随机漫步生成数据并绘制
Aug 27 Python
计算机二级python学习教程(3) python语言基本数据类型
May 16 Python
python启动应用程序和终止应用程序的方法
Jun 28 Python
python 列表、字典和集合的添加和删除操作
Dec 16 Python
Django通过设置CORS解决跨域问题
Nov 26 Python
python 逐步回归算法
Apr 06 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
phpBB BBcode处理的漏洞
2006/10/09 PHP
PHP Token(令牌)设计
2008/03/15 PHP
php 输入输出流详解及示例代码
2016/08/25 PHP
PHP简单字符串过滤方法示例
2016/09/04 PHP
php-fpm添加service服务的例子
2018/04/27 PHP
关于laravel框架中的常用目录路径函数
2019/10/23 PHP
javascript中的array数组使用技巧
2010/01/31 Javascript
js本身的局限性 别让javascript做太多事
2010/03/23 Javascript
Javascript中实现String.startsWith和endsWith方法
2015/06/10 Javascript
分享一个插件实现水珠自动下落效果
2016/06/01 Javascript
angular基于路由控制ui-router实现系统权限控制
2016/09/27 Javascript
JQueryEasyUI之DataGrid数据显示
2016/11/23 Javascript
原生JS简单实现ajax的方法示例
2016/11/29 Javascript
nodejs操作mongodb的填删改查模块的制作及引入实例
2018/01/02 NodeJs
jquery 通过ajax请求获取后台数据显示在表格上的方法
2018/08/08 jQuery
webpack4.x CommonJS模块化浅析
2018/11/09 Javascript
Vue实现 点击显示再点击隐藏效果(点击页面空白区域也隐藏效果)
2020/01/16 Javascript
深入Python函数编程的一些特性
2015/04/13 Python
python数据结构之图的实现方法
2015/07/08 Python
快速了解Python开发中的cookie及简单代码示例
2018/01/17 Python
idea创建springMVC框架和配置小文件的教程图解
2018/09/18 Python
pycharm配置当鼠标悬停时快速提示方法参数
2019/07/31 Python
python源文件的字符编码知识点详解
2021/03/04 Python
Java工程师面试集锦之Spring框架
2013/06/16 面试题
北大青鸟学生求职信
2013/09/24 职场文书
社区学雷锋活动策划方案
2014/01/30 职场文书
庆中秋节主题活动方案
2014/02/03 职场文书
团队激励口号
2014/06/06 职场文书
上课随便讲话检讨书
2014/09/12 职场文书
党的群众路线教育实践活动专题组织生活会发言材料
2014/10/17 职场文书
典型事迹材料范文
2014/12/29 职场文书
因身体原因离职的辞职信范文
2015/05/12 职场文书
2015年普法依法治理工作总结
2015/05/26 职场文书
公司规章制度范本
2015/08/03 职场文书
如何获取numpy array前N个最大值
2021/05/14 Python
详解MySQL数据库千万级数据查询和存储
2021/05/18 MySQL