keras模型保存为tensorflow的二进制模型方式


Posted in Python onMay 25, 2020

最近需要将使用keras训练的模型移植到手机上使用, 因此需要转换到tensorflow的二进制模型。

折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:

# coding=utf-8
import sys

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 """
 Freezes the state of a session into a prunned computation graph.

 Creates a new computation graph where variable nodes are replaced by
 constants taking their current value in the session. The new graph will be
 prunned so subgraphs that are not neccesary to compute the requested
 outputs are removed.
 @param session The TensorFlow session to be frozen.
 @param keep_var_names A list of variable names that should not be frozen,
       or None to freeze all the variables in the graph.
 @param output_names Names of the relevant graph outputs.
 @param clear_devices Remove the device directives from the graph for better portability.
 @return The frozen graph definition.
 """
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph

input_fld = sys.path[0]
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

上面代码实现保存到当前目录的tensor_model目录下。

验证:

import tensorflow as tf
import numpy as np
import PIL.Image as Image
import cv2

def recognize(jpg_path, pb_file_path):
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()

  with open(pb_file_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tensors = tf.import_graph_def(output_graph_def, name="")
   print tensors

  with tf.Session() as sess:
   init = tf.global_variables_initializer()
   sess.run(init)

   op = sess.graph.get_operations()
   
   for m in op:
    print(m.values())

   input_x = sess.graph.get_tensor_by_name("convolution2d_1_input:0") #具体名称看上一段代码的input.name
   print input_x

   out_softmax = sess.graph.get_tensor_by_name("activation_4/Softmax:0") #具体名称看上一段代码的output.name

   print out_softmax

   img = cv2.imread(jpg_path, 0)
   img_out_softmax = sess.run(out_softmax,
          feed_dict={input_x: 1.0 - np.array(img).reshape((-1,28, 28, 1)) / 255.0})

   print "img_out_softmax:", img_out_softmax
   prediction_labels = np.argmax(img_out_softmax, axis=1)
   print "label:", prediction_labels

pb_path = 'tensorflow_model/constant_graph_weights.pb'
img = 'test/6/8_48.jpg'
recognize(img, pb_path)

补充知识:如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用

首先keras训练好的模型通过自带的model.save()保存下来是 .model (.h5) 格式的文件

模型载入是通过 my_model = keras . models . load_model( filepath )

要将该模型转换为.pb 格式的TensorFlow 模型,代码如下:

# -*- coding: utf-8 -*-
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers import Dropout
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential,load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
import collections
from collections import defaultdict
import jieba
import numpy as np
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=True)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

然后模型就存成了.pb格式的文件

问题就来了,这样存下来的.pb格式的文件是frozen model

如果通过TensorFlow serving 启用模型的话,会报错:

E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`

因为TensorFlow serving 希望读取的是saved model

于是需要将frozen model 转化为 saved model 格式,解决方案如下:

from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
 # name="" is important to ensure we don't get spurious prefixing
 tf.import_graph_def(graph_def, name="")
 g = tf.get_default_graph()
 inp = g.get_tensor_by_name(net_model.input.name)
 out = g.get_tensor_by_name(net_model.output.name)

 sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
  tf.saved_model.signature_def_utils.predict_signature_def(
   {"in": inp}, {"out": out})

 builder.add_meta_graph_and_variables(sess,
           [tag_constants.SERVING],
           signature_def_map=sigs)
builder.save()

于是保存下来的saved model 文件夹下就有两个文件:

saved_model.pb variables

其中variables 可以为空

于是将.pb 模型导入serving再读取,成功!

以上这篇keras模型保存为tensorflow的二进制模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python定时器使用示例分享
Feb 16 Python
python列表操作之extend和append的区别实例分析
Jul 28 Python
python正则中最短匹配实现代码
Jan 16 Python
python 时间信息“2018-02-04 18:23:35“ 解析成字典形式的结果代码详解
Apr 19 Python
python os.path模块常用方法实例详解
Sep 16 Python
使用coverage统计python web项目代码覆盖率的方法详解
Aug 05 Python
python Dijkstra算法实现最短路径问题的方法
Sep 19 Python
Python文本处理简单易懂方法解析
Dec 19 Python
Python计算机视觉里的IOU计算实例
Jan 17 Python
python3连接MySQL8.0的两种方式
Feb 17 Python
浅谈pandas dataframe对除数是零的处理
Jul 20 Python
Python3使用Selenium获取session和token方法详解
Feb 16 Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
You might like
php 移除数组重复元素的一点说明
2008/11/27 PHP
PHP+javascript制作带提示的验证码源码分享
2014/05/28 PHP
浅析ThinkPHP缓存之快速缓存(F方法)和动态缓存(S方法)(日常整理)
2015/10/26 PHP
PHP实现15位身份证号转18位的方法分析
2019/10/16 PHP
jquery EasyUI的formatter格式化函数代码
2011/01/12 Javascript
JavaScript快速检测浏览器对CSS3特性的支持情况
2012/09/26 Javascript
原生Js页面滚动延迟加载图片实现原理及过程
2013/06/24 Javascript
如何通过javascript操作web控件的自定义属性
2013/11/25 Javascript
jquery将一个表单序列化为一个对象的方法
2014/01/03 Javascript
dreamweaver 8实现Jquery自动提示
2014/12/04 Javascript
JS原型链怎么理解
2016/06/27 Javascript
js点击按钮实现水波纹效果代码(CSS3和Canves)
2016/09/15 Javascript
js数组与字符串常用方法总结
2017/01/13 Javascript
NodeJs安装npm包一直失败的解决方法
2017/04/28 NodeJs
JavaScript截屏功能的实现代码
2017/07/28 Javascript
解决Vue打包后访问图片/图标不显示的问题
2019/07/25 Javascript
javascript操作元素的常见方法小结
2019/11/13 Javascript
VueCli4项目配置反向代理proxy的方法步骤
2020/05/17 Javascript
在python中获取div的文本内容并和想定结果进行对比详解
2019/01/02 Python
对python mayavi三维绘图的实现详解
2019/01/08 Python
Python3 requests文件下载 期间显示文件信息和下载进度代码实例
2019/08/16 Python
python程序中的线程操作 concurrent模块使用详解
2019/09/23 Python
解决Python二维数组赋值问题
2019/11/28 Python
Python 如何查找特定类型文件
2020/08/17 Python
Html5插件教程之添加浏览器放大镜效果的商品橱窗
2016/01/07 HTML / CSS
韩国三星旗下的一家超市连锁店:Home Plus
2016/07/30 全球购物
牦牛毛户外探险服装:Kora
2019/02/08 全球购物
乌克兰在线电子产品商店:MTA
2019/11/14 全球购物
Java里面如何把一个Array数组转换成Collection, List
2013/07/26 面试题
金融专业毕业生推荐信
2013/11/26 职场文书
最新个人职业生涯规划书
2014/01/22 职场文书
服装设计专业毕业生求职信
2014/04/09 职场文书
大三学生英语考试作弊检讨书
2015/01/01 职场文书
团员自我评价范文
2015/03/10 职场文书
OpenCV中resize函数插值算法的实现过程(五种)
2021/06/05 Python
react中useState使用:如何实现在当前表格直接更改数据
2022/08/05 Javascript