将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python在命令行下使用google翻译(带语音)
Jan 16 Python
python输入整条数据分割存入数组的方法
Nov 13 Python
对python3 中方法各种参数和返回值详解
Dec 15 Python
浅谈django2.0 ForeignKey参数的变化
Aug 06 Python
如何用Python来搭建一个简单的推荐系统
Aug 07 Python
sklearn-SVC实现与类参数详解
Dec 10 Python
python GUI库图形界面开发之PyQt5下拉列表框控件QComboBox详细使用方法与实例
Feb 27 Python
python+gdal+遥感图像拼接(mosaic)的实例
Mar 10 Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 Python
python实现发送带附件的邮件代码分享
Sep 22 Python
OpenCV灰度化之后图片为绿色的解决
Dec 01 Python
pytorch中的model=model.to(device)使用说明
May 24 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
php实现utf-8和GB2312编码相互转换函数代码
2013/02/07 PHP
ThinkPHP3.1新特性之对Ajax的支持更加完善
2014/06/19 PHP
symfony2.4的twig中date用法分析
2016/03/18 PHP
php获取今日开始时间和结束时间的方法
2017/02/27 PHP
php7 安装yar 生成docker镜像
2017/05/09 PHP
JavaScript中的Window窗口对象
2008/01/16 Javascript
jquery实现简单的拖拽效果实例兼容所有主流浏览器
2013/06/21 Javascript
用Jquery.load载入页面实现局部刷新
2014/01/22 Javascript
JavaScript Promise启示录
2014/08/12 Javascript
jQuery实现鼠标划过添加和删除class的方法
2015/06/26 Javascript
JavaScript实现网页加载进度条代码超简单
2015/09/21 Javascript
jQuery的Each比JS原生for循环性能慢很多的原因
2016/07/05 Javascript
详解Vue中添加过渡效果
2017/03/20 Javascript
浅谈在node.js进入文件目录的问题
2018/05/13 Javascript
jQuery实现的老虎机跑动效果示例
2018/12/29 jQuery
selenium+java中用js来完成日期的修改
2019/10/31 Javascript
jQuery 判断元素是否存在然后按需加载内容的实现代码
2020/01/16 jQuery
Vue中watch、computed、updated三者的区别及用法
2020/07/27 Javascript
vue 判断页面是首次进入还是再次刷新的实例
2020/11/05 Javascript
[01:00:49]DOTA2-DPC中国联赛 正赛 Ehome vs iG BO3 第二场 1月31日
2021/03/11 DOTA
python 实现堆排序算法代码
2012/06/05 Python
用pywin32实现windows模拟鼠标及键盘动作
2014/04/22 Python
Python通过DOM和SAX方式解析XML的应用实例分享
2015/11/16 Python
python3实现跳一跳点击跳跃
2018/01/08 Python
python 显示数组全部元素的方法
2018/04/19 Python
详解CSS3中字体平滑处理和抗锯齿渲染
2017/03/29 HTML / CSS
Html5上传图片 移动端、PC端通用代码
2016/06/08 HTML / CSS
十岁生日同学答谢词
2014/01/19 职场文书
家电业务员岗位职责
2014/03/10 职场文书
二手房买卖协议书
2014/04/10 职场文书
技校毕业生自荐书
2014/05/23 职场文书
广告学专业毕业生自荐信
2014/05/28 职场文书
移交协议书
2014/08/19 职场文书
公司仓管员岗位职责
2015/04/01 职场文书
2015年护理工作总结范文
2015/04/03 职场文书
倡议书范文大全
2015/04/28 职场文书