将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中字典的基础知识归纳小结
Aug 19 Python
解读Python编程中的命名空间与作用域
Oct 16 Python
举例讲解Python设计模式编程中对抽象工厂模式的运用
Mar 02 Python
Python爬取三国演义的实现方法
Sep 12 Python
python读取和保存图片5种方法对比
Sep 12 Python
Python一行代码实现快速排序的方法
Apr 30 Python
Python Django Vue 项目创建过程详解
Jul 29 Python
使用python os模块复制文件到指定文件夹的方法
Aug 22 Python
详解python datetime模块
Aug 17 Python
使用python对excel表格处理的一些小功能
Jan 25 Python
Django实现简单的分页功能
Feb 22 Python
Python探索生命起源 matplotlib细胞自动机动画演示
Apr 21 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
pdo中使用参数化查询sql
2011/08/11 PHP
php摘要生成函数(无乱码)
2012/02/04 PHP
php格式化日期和时间格式化示例分享
2014/02/24 PHP
php使用for语句输出三角形的方法
2015/06/09 PHP
简单谈谈PHP中的Reload操作
2016/12/12 PHP
PHP使用POP3读取邮箱接收邮件的示例代码
2020/07/08 PHP
Prototype使用指南之array.js
2007/01/10 Javascript
如何书写高质量jQuery代码(使用jquery性能问题)
2014/06/30 Javascript
分享一则javascript 调试技巧
2015/01/02 Javascript
玩转JavaScript OOP - 类的实现详解
2016/06/08 Javascript
node.js express安装及示例网站搭建方法(分享)
2016/08/22 Javascript
jquery注册文本框获取焦点清空,失去焦点赋值的简单实例
2016/09/08 Javascript
ros::spin() 和 ros::spinOnce()函数的区别及详解
2016/10/01 Javascript
Nodejs之TCP服务端与客户端聊天程序详解
2017/07/07 NodeJs
JavaScript实现左侧菜单效果
2017/12/14 Javascript
Vue打包后出现一些map文件的解决方法
2018/02/13 Javascript
浅析Visual Studio Code断点调试Vue
2018/02/27 Javascript
原生JS实现的碰撞检测功能示例
2018/05/18 Javascript
微信小程序实现图片上传功能
2018/05/28 Javascript
Vuex的初探与实战小结
2018/11/26 Javascript
jQuery实现当拉动滚动条到底部加载数据的方法分析
2019/01/24 jQuery
详解微信图片防盗链“此图片来自微信公众平台 未经允许不得引用”的解决方案
2019/04/04 Javascript
[01:05]DOTA2完美大师赛趣味视频之选手教你打职业
2017/11/23 DOTA
Python实现的数据结构与算法之基本搜索详解
2015/04/22 Python
python实现requests发送/上传多个文件的示例
2018/06/04 Python
selenium python 实现基本自动化测试的示例代码
2019/02/25 Python
CSS3 实用技巧:实现黑白图像效果示例代码
2013/07/11 HTML / CSS
Mybag美国/加拿大:英国奢华包包和名牌手袋网站
2020/02/16 全球购物
广告学专业推荐信范文
2013/11/23 职场文书
工作年限证明模板
2014/11/01 职场文书
3.15消费者权益日活动总结
2015/02/09 职场文书
投诉信回复范文
2015/07/03 职场文书
新郎新娘致辞
2015/07/31 职场文书
干货干货!2019最新优秀创业计划书
2019/03/21 职场文书
Linux安装Nginx步骤详解
2021/03/31 Servers
MySQL 使用SQL语句修改表名的实现
2021/04/07 MySQL