将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈Python使用Bottle来提供一个简单的web服务
Dec 27 Python
Python3之读取连接过的网络并定位的方法
Apr 22 Python
python 将md5转为16字节的方法
May 29 Python
python实现linux下抓包并存库功能
Jul 18 Python
Python 中PyQt5 点击主窗口弹出另一个窗口的实现方法
Jul 04 Python
python3的数据类型及数据类型转换实例详解
Aug 20 Python
Python实现语音识别和语音合成功能
Sep 20 Python
Pandas 缺失数据处理的实现
Nov 04 Python
Django Admin 上传文件到七牛云的示例代码
Jun 20 Python
实例讲解Python 迭代器与生成器
Jul 08 Python
python pillow库的基础使用教程
Jan 13 Python
Python实现视频自动打码的示例代码
Apr 08 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
PHP 上传文件的方法(类)
2009/07/30 PHP
PHP学习笔记之三 数据库基本操作
2011/01/17 PHP
PHP表单递交控件名称含有点号(.)会被转化为下划线(_)的处理方法
2013/01/06 PHP
基于Zookeeper的使用详解
2013/05/02 PHP
php命名空间学习详解
2014/02/27 PHP
php实现RSA加密类实例
2015/03/26 PHP
PHPExcel导出2003和2007的excel文档功能示例
2017/01/04 PHP
PHP下用Swoole实现Actor并发模型的方法
2019/06/12 PHP
Yii框架where查询用法实例分析
2019/10/22 PHP
input+select(multiple) 实现下拉框输入值
2009/05/21 Javascript
jQuery简单实现banner图片切换
2014/01/02 Javascript
jQuery Easyui DataGrid点击某个单元格即进入编辑状态焦点移开后保存数据
2016/08/15 Javascript
jQuery实现table表格checkbox全选的方法分析
2018/07/04 jQuery
js实现input密码框显示/隐藏功能
2020/09/10 Javascript
Node.js 深度调试方法解析
2020/07/28 Javascript
使用JavaScript和MQTT开发物联网应用示例解析
2020/08/07 Javascript
Vuejs通过拖动改变元素宽度实现自适应
2020/09/02 Javascript
[01:33]完美世界DOTA2联赛PWL S3 集锦第二期
2020/12/21 DOTA
Python3.5 Pandas模块缺失值处理和层次索引实例详解
2019/04/23 Python
Python定时发送天气预报邮件代码实例
2019/09/09 Python
Python3 main函数使用sys.argv传入多个参数的实现
2019/12/25 Python
利用PyQt中的QThread类实现多线程
2020/02/18 Python
Docker如何部署Python项目的实现详解
2020/10/26 Python
Calphalon美国官网:美国顶级锅具品牌
2020/02/05 全球购物
入党申请书自我鉴定
2013/10/12 职场文书
求职简历中个人的自我评价
2013/12/01 职场文书
设备管理实施方案
2014/05/31 职场文书
合伙经营协议书范本
2014/09/13 职场文书
商场收银员岗位职责
2015/04/07 职场文书
演讲开场白和结束语
2015/05/29 职场文书
草房子读书笔记
2015/06/29 职场文书
护士岗前培训心得体会
2016/01/08 职场文书
iPhone13再次曝光
2021/04/15 数码科技
win11如何查看端口是否被占用? Win11查看端口是否占用的技巧
2022/04/05 数码科技
PYTHON 使用 Pandas 删除某列指定值所在的行
2022/04/28 Python
详解flex:1什么意思
2022/07/23 HTML / CSS