将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中多线程thread与threading的实现方法
Aug 18 Python
详解Django框架中的视图级缓存
Jul 23 Python
python生成器表达式和列表解析
Mar 10 Python
Ubuntu安装Jupyter Notebook教程
Oct 18 Python
用python制作游戏外挂
Jan 04 Python
对python requests发送json格式数据的实例详解
Dec 19 Python
Python获取航线信息并且制作成图的讲解
Jan 03 Python
python实现可变变量名方法详解
Jul 01 Python
Jupyter Notebook 文件默认目录的查看以及更改步骤
Apr 14 Python
Python类中的装饰器在当前类中的声明与调用详解
Apr 15 Python
python实现俄罗斯方块小游戏
Apr 24 Python
Python基于内置函数type创建新类型
Oct 22 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
德劲1104的电路分析与改良
2021/03/01 无线电
CI框架学习笔记(一) - 环境安装、基本术语和框架流程
2014/10/26 PHP
PHP strcmp()和strcasecmp()的区别实例
2016/11/05 PHP
PHP开发中解决并发问题的几种实现方法分析
2017/11/13 PHP
js实现遮罩层弹出框的方法
2015/01/15 Javascript
JS实现动态给图片添加边框的方法
2015/04/01 Javascript
js实现的tab标签切换效果代码分享
2015/08/25 Javascript
Vue.js组件使用开发实例教程
2016/11/01 Javascript
基于vue2.0+vuex+localStorage开发的本地记事本示例
2017/02/28 Javascript
vue-cli如何快速构建vue项目
2017/04/26 Javascript
vue使用axios时关于this的指向问题详解
2017/12/22 Javascript
Angular2整合其他插件的方法
2018/01/20 Javascript
vue项目实现记住密码到cookie功能示例(附源码)
2018/01/31 Javascript
JavaScript回调函数callback用法解析
2020/01/14 Javascript
jquery插件实现轮播图效果
2020/10/19 jQuery
解决vue中使用less/sass及使用中遇到无效的问题
2020/10/24 Javascript
[01:45]典藏宝瓶2+祈求者身心——这就是DOTA2TI9总奖金突破3000万美元的秘密
2019/07/21 DOTA
初学Python实用技巧两则
2014/08/29 Python
Python实现的简单算术游戏实例
2015/05/26 Python
对python中执行DOS命令的3种方法总结
2018/05/12 Python
selenium python 实现基本自动化测试的示例代码
2019/02/25 Python
简单了解python 生成器 列表推导式 生成器表达式
2019/08/22 Python
Python (Win)readline和tab补全的安装方法
2019/08/27 Python
使用Pandas将inf, nan转化成特定的值
2019/12/19 Python
基于python实现获取网页图片过程解析
2020/05/11 Python
详解用Python调用百度地图正/逆地理编码API
2020/07/02 Python
Python urllib库如何添加headers过程解析
2020/10/05 Python
澳大利亚UGG工厂直销:Australian Ugg Boots
2017/10/14 全球购物
泰国演唱会订票网站:StubHub泰国
2018/02/26 全球购物
L’urv官网:精品女性运动服品牌
2019/07/07 全球购物
酒桌上的开场白
2015/06/01 职场文书
摘录式读书笔记
2015/07/01 职场文书
2016春季幼儿园小班开学寄语
2015/12/03 职场文书
python如何进行基准测试
2021/04/26 Python
MySQL REVOKE实现删除用户权限
2021/06/18 MySQL
vue route新窗口跳转页面并且携带与接收参数
2022/04/10 Vue.js