将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Flask框架中Flask-Admin库的简单入门指引
Apr 07 Python
python通过自定义isnumber函数判断字符串是否为数字的方法
Apr 23 Python
Python中的数据对象持久化存储模块pickle的使用示例
Mar 03 Python
Python中规范定义命名空间的一些建议
Jun 04 Python
Python 类的继承实例详解
Mar 25 Python
python使用opencv按一定间隔截取视频帧
Mar 06 Python
使用python画社交网络图实例代码
Jul 10 Python
使用Python和Scribus创建一个RGB立方体的方法
Jul 17 Python
python内存监控工具memory_profiler和guppy的用法详解
Jul 29 Python
解决Django中调用keras的模型出现的问题
Aug 07 Python
python输出决策树图形的例子
Aug 09 Python
Python matplotlib多个子图绘制整合
Apr 13 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
曾在DC漫画界反派角色扮演的演员,谁才是你心目中的小丑之王?
2020/04/09 欧美动漫
数组任意位置插入元素,删除特定元素的实例
2017/03/02 PHP
PHP自定义函数实现数组比较功能示例
2017/10/19 PHP
laravel实现按月或天或小时统计mysql数据的方法
2019/10/09 PHP
laravel excel 上传文件保存到本地服务器功能
2019/11/14 PHP
ASP小贴士/ASP Tips javascript tips可以当桌面
2009/12/10 Javascript
jquery 全局AJAX事件使用代码
2010/11/05 Javascript
js函数中onmousedown和onclick的区别和联系探讨
2013/05/19 Javascript
json数据与字符串的相互转化示例
2013/09/18 Javascript
js判断游览器类型及版本号的代码
2014/05/11 Javascript
JavaScript中使用Object.create()创建对象介绍
2014/12/30 Javascript
jquery实现隐藏在左侧的弹性弹出菜单效果
2015/09/18 Javascript
angularJS 指令封装回到顶部示例详解
2017/01/22 Javascript
javascript设计模式之模块模式学习笔记
2017/02/15 Javascript
JavaScript基本语法_动力节点Java学院整理
2017/06/26 Javascript
JS模拟超市简易收银台小程序代码解析
2017/08/18 Javascript
JS闭包的几种常见形式实例详解
2017/09/16 Javascript
vue ajax 拦截原理与实现方法示例
2019/11/29 Javascript
node.JS路径解析之PATH模块使用方法详解
2020/02/06 Javascript
JavaScript中继承原理与用法实例入门
2020/05/09 Javascript
Python中super的用法实例
2015/05/28 Python
Linux中Python 环境软件包安装步骤
2016/03/31 Python
python中装饰器级连的使用方法示例
2017/09/29 Python
ubuntu环境下python虚拟环境的安装过程
2018/01/07 Python
selenium+python 去除启动的黑色cmd窗口方法
2018/05/22 Python
Python计算一个给定时间点前一个月和后一个月第一天的方法
2018/05/29 Python
详解Python的三种拷贝方式
2020/02/11 Python
Python读取yaml文件的详细教程
2020/07/21 Python
Python tkinter实现日期选择器
2021/02/22 Python
医学院毕业生自荐信
2013/11/08 职场文书
师范学院教师自荐书
2014/01/31 职场文书
圣诞节红领巾广播稿
2014/02/03 职场文书
会计求职信
2014/05/29 职场文书
年度考核个人总结
2015/03/06 职场文书
python字典进行运算原理及实例分享
2021/08/02 Python
聊聊redis-dump工具安装问题
2022/01/18 Redis