keras模型保存为tensorflow的二进制模型方式


Posted in Python onMay 25, 2020

最近需要将使用keras训练的模型移植到手机上使用, 因此需要转换到tensorflow的二进制模型。

折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:

# coding=utf-8
import sys

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 """
 Freezes the state of a session into a prunned computation graph.

 Creates a new computation graph where variable nodes are replaced by
 constants taking their current value in the session. The new graph will be
 prunned so subgraphs that are not neccesary to compute the requested
 outputs are removed.
 @param session The TensorFlow session to be frozen.
 @param keep_var_names A list of variable names that should not be frozen,
       or None to freeze all the variables in the graph.
 @param output_names Names of the relevant graph outputs.
 @param clear_devices Remove the device directives from the graph for better portability.
 @return The frozen graph definition.
 """
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph

input_fld = sys.path[0]
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

上面代码实现保存到当前目录的tensor_model目录下。

验证:

import tensorflow as tf
import numpy as np
import PIL.Image as Image
import cv2

def recognize(jpg_path, pb_file_path):
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()

  with open(pb_file_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tensors = tf.import_graph_def(output_graph_def, name="")
   print tensors

  with tf.Session() as sess:
   init = tf.global_variables_initializer()
   sess.run(init)

   op = sess.graph.get_operations()
   
   for m in op:
    print(m.values())

   input_x = sess.graph.get_tensor_by_name("convolution2d_1_input:0") #具体名称看上一段代码的input.name
   print input_x

   out_softmax = sess.graph.get_tensor_by_name("activation_4/Softmax:0") #具体名称看上一段代码的output.name

   print out_softmax

   img = cv2.imread(jpg_path, 0)
   img_out_softmax = sess.run(out_softmax,
          feed_dict={input_x: 1.0 - np.array(img).reshape((-1,28, 28, 1)) / 255.0})

   print "img_out_softmax:", img_out_softmax
   prediction_labels = np.argmax(img_out_softmax, axis=1)
   print "label:", prediction_labels

pb_path = 'tensorflow_model/constant_graph_weights.pb'
img = 'test/6/8_48.jpg'
recognize(img, pb_path)

补充知识:如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用

首先keras训练好的模型通过自带的model.save()保存下来是 .model (.h5) 格式的文件

模型载入是通过 my_model = keras . models . load_model( filepath )

要将该模型转换为.pb 格式的TensorFlow 模型,代码如下:

# -*- coding: utf-8 -*-
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers import Dropout
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential,load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
import collections
from collections import defaultdict
import jieba
import numpy as np
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=True)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

然后模型就存成了.pb格式的文件

问题就来了,这样存下来的.pb格式的文件是frozen model

如果通过TensorFlow serving 启用模型的话,会报错:

E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`

因为TensorFlow serving 希望读取的是saved model

于是需要将frozen model 转化为 saved model 格式,解决方案如下:

from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
 # name="" is important to ensure we don't get spurious prefixing
 tf.import_graph_def(graph_def, name="")
 g = tf.get_default_graph()
 inp = g.get_tensor_by_name(net_model.input.name)
 out = g.get_tensor_by_name(net_model.output.name)

 sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
  tf.saved_model.signature_def_utils.predict_signature_def(
   {"in": inp}, {"out": out})

 builder.add_meta_graph_and_variables(sess,
           [tag_constants.SERVING],
           signature_def_map=sigs)
builder.save()

于是保存下来的saved model 文件夹下就有两个文件:

saved_model.pb variables

其中variables 可以为空

于是将.pb 模型导入serving再读取,成功!

以上这篇keras模型保存为tensorflow的二进制模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3实现将文件归档到zip文件及从zip文件中读取数据的方法
May 22 Python
python中利用await关键字如何等待Future对象完成详解
Sep 07 Python
Python编程之基于概率论的分类方法:朴素贝叶斯
Nov 11 Python
tensorflow 打印内存中的变量方法
Jul 30 Python
matplotlib调整子图间距,调整整体空白的方法
Aug 03 Python
Python正则表达式和re库知识点总结
Feb 11 Python
Python除法之传统除法、Floor除法及真除法实例详解
May 23 Python
pandas数据筛选和csv操作的实现方法
Jul 02 Python
python 伯努利分布详解
Feb 25 Python
python实现密码验证合格程序的思路详解
Jun 01 Python
Python SMTP发送电子邮件的示例
Sep 23 Python
Python3爬虫ChromeDriver的安装实例
Feb 06 Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
You might like
人大复印资料处理程序_补充篇
2006/10/09 PHP
PHP 操作文件的一些FAQ总结
2009/02/12 PHP
php daodb插入、更新与删除数据
2009/03/19 PHP
让CodeIgniter数据库缓存自动过期的处理的方法
2014/06/12 PHP
php中chdir()函数用法实例
2014/11/13 PHP
php获取用户浏览器版本的方法
2015/01/03 PHP
php计算函数执行时间的方法
2015/03/20 PHP
extjs 学习笔记(一) 一些基础知识
2009/10/13 Javascript
JS中Iframe之间传值及子页面与父页面应用
2013/03/11 Javascript
js 窗口抖动示例
2013/09/04 Javascript
如何用jquery控制表格奇偶行及活动行颜色
2014/04/20 Javascript
JavaScript indexOf方法入门实例(计算指定字符在字符串中首次出现的位置)
2014/10/17 Javascript
使用js画图之圆、弧、扇形
2015/01/12 Javascript
jQuery+PHP实现动态数字展示特效
2015/03/14 Javascript
jquery实现简单手风琴菜单效果实例
2015/06/13 Javascript
javascript比较两个日期相差天数的方法
2015/07/24 Javascript
javascript中闭包(Closure)详解
2016/01/06 Javascript
深入理解setTimeout函数和setInterval函数
2016/05/20 Javascript
JavaScript奇技淫巧44招【实用】
2016/12/11 Javascript
JavaScript中各数制转换全面总结
2017/08/21 Javascript
利用Vue2.x开发实现JSON树的方法
2018/01/04 Javascript
通过jquery获取上传文件名称、类型和大小的实现代码
2018/04/19 jQuery
JS JQuery获取data-*属性值方法解析
2020/09/01 jQuery
微信小程序实现点击生成随机验证码
2020/09/09 Javascript
[41:37]DOTA2北京网鱼队选拔赛——冲击职业之路
2015/04/13 DOTA
[01:53]2016完美“圣”典风云人物:Maybe专访
2016/12/05 DOTA
pycharm设置鼠标悬停查看方法设置
2019/07/29 Python
Django shell调试models输出的SQL语句方法
2019/08/29 Python
Python类中方法getitem和getattr详解
2019/08/30 Python
使用Python将字符串转换为格式化的日期时间字符串
2019/09/01 Python
基于python实现微信好友数据分析(简单)
2020/02/16 Python
党员培训思想汇报
2014/01/07 职场文书
校园安全广播稿范文
2014/09/25 职场文书
贴吧吧主申请感言
2015/08/03 职场文书
python自动统计zabbix系统监控覆盖率的示例代码
2021/04/03 Python
CSS中calc(100%-100px)不加空格不生效
2023/05/07 HTML / CSS