tensorflow pb to tflite 精度下降详解


Posted in Python onMay 25, 2020

之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。

思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能支持业务需求了,最后就把OCR模型调用写在服务端了,但是精度下降的原因目前也没有找到,现在这里记录一下。

工作思路:

1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;

但是使用python 的tf interpreter 调用tflite文件就已经出现精度下降的问题,android端部署也是一样。

1.网络结构

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import tensorflow as tf
slim = tf.contrib.slim
 
def ttnet(images, num_classes=10, is_training=False,
   dropout_keep_prob=0.5,
   prediction_fn=slim.softmax,
   scope='TtNet'):
 end_points = {}
 
 with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
 net = slim.conv2d(images, 32, [3, 3], scope='conv1')
 # net = slim.conv2d(images, 64, [3, 3], scope='conv1_2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1')
 # net = slim.conv2d(net, 128, [3, 3], scope='conv2_1')
 net = slim.conv2d(net, 64, [3, 3], scope='conv2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
 net = slim.conv2d(net, 128, [3, 3], scope='conv3')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
 net = slim.conv2d(net, 256, [3, 3], scope='conv4')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2')
 # net = slim.conv2d(net, 512, [3, 3], scope='conv5')
 # net = slim.max_pool2d(net, [2, 2], 2, scope='pool5')
 net = slim.flatten(net)
 end_points['Flatten'] = net
 
 # net = slim.fully_connected(net, 1024, scope='fc3')
 net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
      scope='dropout3')
 logits = slim.fully_connected(net, num_classes, activation_fn=None,
         scope='fc4') 
 end_points['Logits'] = logits
 end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
 
 return logits, end_points
ttnet.default_image_size = 28
 
def ttnet_arg_scope(weight_decay=0.0):
 with slim.arg_scope(
  [slim.conv2d, slim.fully_connected],
  weights_regularizer=slim.l2_regularizer(weight_decay),
  weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
  activation_fn=tf.nn.relu) as sc:
 return sc

基于slim,由于是一个比较简单的分类问题,网络结构也很简单,几个卷积加池化。

测试效果是很棒的。真实样本测试集能达到99%+的准确率。

2.模型固化,生成pb文件

#coding:utf-8
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
import cv2
import os
import numpy as np
from datasets import dataset_factory
from preprocessing import preprocessing_factory
from tensorflow.python.platform import gfile
slim = tf.contrib.slim
#todo
#support arbitray image size and num_class
 
tf.app.flags.DEFINE_string(
 'checkpoint_path', '/tmp/tfmodel/',
 'The directory where the model was written to or an absolute path to a '
 'checkpoint file.')
 
tf.app.flags.DEFINE_string(
 'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
 'preprocessing_name', None, 'The name of the preprocessing to use. If left '
 'as `None`, then the model_name flag is used.')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer(
 'eval_image_size', None, 'Eval image size')
tf.app.flags.DEFINE_integer(
 'eval_image_height', None, 'Eval image height')
tf.app.flags.DEFINE_integer(
 'eval_image_width', None, 'Eval image width')
tf.app.flags.DEFINE_string(
 'export_path', './ttnet_1.0_37_32.pb', 'the export path of the pd file')
FLAGS = tf.app.flags.FLAGS
NUM_CLASSES = 37
 
def main(_):
 network_fn = nets_factory.get_network_fn(
  FLAGS.model_name,
  num_classes=NUM_CLASSES,
  is_training=False)
 # pre_image = tf.placeholder(tf.float32, [None, None, 3], name='input_data')
 # preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
 # image_preprocessing_fn = preprocessing_factory.get_preprocessing(
 #  preprocessing_name,
 #  is_training=False)
 # image = image_preprocessing_fn(pre_image, FLAGS.eval_image_height, FLAGS.eval_image_width)
 # images2 = tf.expand_dims(image, 0)
 images2 = tf.placeholder(tf.float32, (None,32, 32, 3),name='input_data')
 logits, endpoints = network_fn(images2)
 with tf.Session() as sess:
 output = tf.identity(endpoints['Predictions'],name="output_data")
 with gfile.GFile(FLAGS.export_path, 'wb') as f:
  f.write(sess.graph_def.SerializeToString())
 
if __name__ == '__main__':
 tf.app.run()

3.生成tflite文件

import tensorflow as tf
 
graph_def_file = "/datastore1/Colonist_Lord/Colonist_Lord/workspace/models/model_files/passport_model_with_tflite/ocr_frozen.pb"
input_arrays = ["input_data"]
output_arrays = ["output_data"]
 
converter = tf.lite.TFLiteConverter.from_frozen_graph(
 graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

使用pb文件进行测试,效果正常;使用tflite文件进行测试,精度下降严重。下面附上pb与tflite测试代码。

pb测试代码

with tf.gfile.GFile(graph_filename, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 
with tf.Graph().as_default() as graph:
 tf.import_graph_def(graph_def)
 input_node = graph.get_tensor_by_name('import/input_data:0')
 output_node = graph.get_tensor_by_name('import/output_data:0')
 with tf.Session() as sess:
  for image_file in image_files:
   abs_path = os.path.join(image_folder, image_file)
   img = cv2.imread(abs_path).astype(np.float32)
   img = cv2.resize(img, (int(input_node.shape[1]), int(input_node.shape[2])))
   output_data = sess.run(output_node, feed_dict={input_node: [img]})
   index = np.argmax(output_data)
   label = dict_laebl[index]
   dst_floder = os.path.join(result_folder, label)
   if not os.path.exists(dst_floder):
    os.mkdir(dst_floder)
   cv2.imwrite(os.path.join(dst_floder, image_file), img)
   count += 1

tflite测试代码

model_path = "converted_model.tflite" #"/datastore1/Colonist_Lord/Colonist_Lord/data/passport_char/ocr.tflite"
interpreter = tf.contrib.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
 
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for image_file in image_files:
 abs_path = os.path.join(image_folder,image_file)
 img = cv2.imread(abs_path).astype(np.float32)
 img = cv2.resize(img, tuple(input_details[0]['shape'][1:3]))
 # input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
 interpreter.set_tensor(input_details[0]['index'], [img])
 
 interpreter.invoke()
 output_data = interpreter.get_tensor(output_details[0]['index'])
 index = np.argmax(output_data)
 label = dict_laebl[index]
 dst_floder = os.path.join(result_folder,label)
 if not os.path.exists(dst_floder):
  os.mkdir(dst_floder)
 cv2.imwrite(os.path.join(dst_floder,image_file),img)
 count+=1

最后也算是绕过这个问题解决了业务需求,后面有空的话,还是会花时间研究一下这个问题。

如果有哪个大佬知道原因,希望不吝赐教。

补充知识:.pb 转tflite代码,使用量化,减小体积,converter.post_training_quantize = True

import tensorflow as tf

path = "/home/python/Downloads/a.pb" # pb文件位置和文件名
inputs = ["input_images"] # 模型文件的输入节点名称
classes = ['feature_fusion/Conv_7/Sigmoid','feature_fusion/concat_3'] # 模型文件的输出节点名称
# converter = tf.contrib.lite.TocoConverter.from_frozen_graph(path, inputs, classes, input_shapes={'input_images':[1, 320, 320, 3]})
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, classes,
              input_shapes={'input_images': [1, 320, 320, 3]})
converter.post_training_quantize = True
tflite_model = converter.convert()
open("/home/python/Downloads/aNew.tflite", "wb").write(tflite_model)

以上这篇tensorflow pb to tflite 精度下降详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则表达式去掉数字中的逗号(python正则匹配逗号)
Dec 25 Python
Python爬取网易云音乐热门评论
Mar 31 Python
Python对象类型及其运算方法(详解)
Jul 05 Python
python机器学习之神经网络(一)
Dec 20 Python
浅谈python 读excel数值为浮点型的问题
Dec 25 Python
Django项目之Elasticsearch搜索引擎的实例
Aug 21 Python
基于python2.7实现图形密码生成器的实例代码
Nov 05 Python
浅析NumPy 切片和索引
Sep 02 Python
Python基于staticmethod装饰器标示静态方法
Oct 17 Python
tensorflow与numpy的版本兼容性问题的解决
Jan 08 Python
Python实现简单猜数字游戏
Feb 03 Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
You might like
php代码收集表单内容并写入文件的代码
2012/01/29 PHP
如何阻止网站被恶意反向代理访问(防网站镜像)
2014/03/18 PHP
php打乱数组二维数组多维数组的简单实例
2016/06/17 PHP
php实现支持中文的文件下载功能示例
2017/08/30 PHP
js 上传图片预览问题
2010/12/06 Javascript
jquery 回车事件实现代码
2011/08/23 Javascript
input链接页面、打开新网页等等的具体实现
2013/12/30 Javascript
jquery $("#variable") 循环改变variable的值示例
2014/02/23 Javascript
父页面显示遮罩层弹出半透明状态的dialog
2014/03/04 Javascript
JavaScript也谈内存优化
2014/06/06 Javascript
jquery插件方式实现table查询功能的简单实例
2016/06/06 Javascript
微信小程序request出现400的问题解决办法
2017/05/23 Javascript
JS实现上传图片的三种方法并实现预览图片功能
2017/07/14 Javascript
JS获取字符对应的ASCII码实例
2017/09/10 Javascript
基于原生js运动方式关键点的总结(推荐)
2017/10/01 Javascript
基于webpack4.X从零搭建React脚手架的方法步骤
2018/12/23 Javascript
详解微信小程序支付流程与梳理
2019/07/16 Javascript
python中 ? : 三元表达式的使用介绍
2013/10/09 Python
在Django中同时使用多个配置文件的方法
2015/07/22 Python
Python 专题六 局部变量、全局变量global、导入模块变量
2017/03/20 Python
python实现的二叉树定义与遍历算法实例
2017/06/30 Python
Python cookbook(数据结构与算法)找出序列中出现次数最多的元素算法示例
2018/03/15 Python
pytorch构建网络模型的4种方法
2018/04/13 Python
python 脚本生成随机 字母 + 数字密码功能
2018/05/26 Python
Selenium的使用详解
2018/10/19 Python
python requests库爬取豆瓣电视剧数据并保存到本地详解
2019/08/10 Python
时尚的CSS3进度条效果
2012/02/22 HTML / CSS
纯CSS3实现的井字棋游戏
2020/11/25 HTML / CSS
特步官方商城:Xtep
2017/03/21 全球购物
Bogner美国官网:滑雪服中的”Dior”
2018/01/30 全球购物
个人简历自荐信
2013/12/05 职场文书
美容师的职业规划书
2013/12/27 职场文书
百货商场楼层班组长竞聘书
2014/03/31 职场文书
品酒会策划方案
2014/05/26 职场文书
2015年元宵节活动总结
2015/02/06 职场文书
《学会生存》读后感3篇
2019/12/09 职场文书