将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
centos 下面安装python2.7 +pip +mysqld
Nov 18 Python
Python中__init__.py文件的作用详解
Sep 18 Python
python中yaml配置文件模块的使用详解
Apr 27 Python
异步任务队列Celery在Django中的使用方法
Jun 07 Python
用python脚本24小时刷浏览器的访问量方法
Dec 07 Python
学习python可以干什么
Feb 26 Python
Django ImageFiled上传照片并显示的方法
Jul 28 Python
Python IDE Pycharm中的快捷键列表用法
Aug 08 Python
Python封装成可带参数的EXE安装包实例
Aug 24 Python
python torch.utils.data.DataLoader使用方法
Apr 02 Python
Python如何使用ElementTree解析xml
Oct 12 Python
Python+Appium自动化测试的实战
Jun 30 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
学习php笔记 字符串处理
2010/10/19 PHP
Zend的MVC机制使用分析(一)
2013/05/02 PHP
PHP内核探索:哈希表碰撞攻击原理
2015/07/31 PHP
PHP静态成员变量和非静态成员变量详解
2017/02/14 PHP
php mysql_list_dbs()函数用法示例
2017/03/29 PHP
关于Mozilla浏览器不支持innerText的解决办法
2011/01/01 Javascript
jQuery学习笔记之jQuery+CSS3的浏览器兼容性
2015/01/19 Javascript
Ajax清除浏览器js、css、图片缓存的方法
2015/08/06 Javascript
JS时间特效最常用的三款
2015/08/19 Javascript
jQuery结合CSS制作动态的下拉菜单
2015/10/27 Javascript
javascript中的 object 和 function小结
2016/08/14 Javascript
JavaScript之DOM_动力节点Java学院整理
2017/07/03 Javascript
webpack4.0 入门实践教程
2018/10/08 Javascript
微信小程序MUI侧滑导航菜单示例(Popup弹出式,左侧滑动,右侧不动)
2019/01/23 Javascript
解析vue、angular深度作用选择器
2019/09/11 Javascript
vue 添加和编辑用同一个表单,el-form表单提交后清空表单数据操作
2020/08/03 Javascript
JS实现密码框效果
2020/09/10 Javascript
python简单程序读取串口信息的方法
2015/03/13 Python
Python fileinput模块使用实例
2015/06/03 Python
使用python遍历指定城市的一周气温
2017/03/31 Python
python中WSGI是什么,Python应用WSGI详解
2017/11/24 Python
Flask框架响应、调度方法和蓝图操作实例分析
2018/07/24 Python
基于python3监控服务器状态进行邮件报警
2019/10/19 Python
python使用梯度下降和牛顿法寻找Rosenbrock函数最小值实例
2020/04/02 Python
基于Python的OCR实现示例
2020/04/03 Python
Python selenium使用autoIT上传附件过程详解
2020/05/26 Python
完美解决keras保存好的model不能成功加载问题
2020/06/11 Python
Django Auth用户认证组件实现代码
2020/10/13 Python
屈臣氏俄罗斯在线商店:Watsons俄罗斯
2020/08/03 全球购物
应用数学自荐书范文
2013/11/24 职场文书
感恩教育月活动总结
2014/07/07 职场文书
销售人员求职信
2014/07/22 职场文书
入党积极分子十八届四中全会思想汇报
2014/10/23 职场文书
24年收藏2000多部退役军用电台
2022/02/18 无线电
spring注解 @PropertySource配置数据源全流程
2022/03/25 Java/Android
Springboot集成kafka高级应用实战分享
2022/08/14 Java/Android