tensorflow pb to tflite 精度下降详解


Posted in Python onMay 25, 2020

之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。

思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能支持业务需求了,最后就把OCR模型调用写在服务端了,但是精度下降的原因目前也没有找到,现在这里记录一下。

工作思路:

1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;

但是使用python 的tf interpreter 调用tflite文件就已经出现精度下降的问题,android端部署也是一样。

1.网络结构

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import tensorflow as tf
slim = tf.contrib.slim
 
def ttnet(images, num_classes=10, is_training=False,
   dropout_keep_prob=0.5,
   prediction_fn=slim.softmax,
   scope='TtNet'):
 end_points = {}
 
 with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
 net = slim.conv2d(images, 32, [3, 3], scope='conv1')
 # net = slim.conv2d(images, 64, [3, 3], scope='conv1_2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1')
 # net = slim.conv2d(net, 128, [3, 3], scope='conv2_1')
 net = slim.conv2d(net, 64, [3, 3], scope='conv2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
 net = slim.conv2d(net, 128, [3, 3], scope='conv3')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
 net = slim.conv2d(net, 256, [3, 3], scope='conv4')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2')
 # net = slim.conv2d(net, 512, [3, 3], scope='conv5')
 # net = slim.max_pool2d(net, [2, 2], 2, scope='pool5')
 net = slim.flatten(net)
 end_points['Flatten'] = net
 
 # net = slim.fully_connected(net, 1024, scope='fc3')
 net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
      scope='dropout3')
 logits = slim.fully_connected(net, num_classes, activation_fn=None,
         scope='fc4') 
 end_points['Logits'] = logits
 end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
 
 return logits, end_points
ttnet.default_image_size = 28
 
def ttnet_arg_scope(weight_decay=0.0):
 with slim.arg_scope(
  [slim.conv2d, slim.fully_connected],
  weights_regularizer=slim.l2_regularizer(weight_decay),
  weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
  activation_fn=tf.nn.relu) as sc:
 return sc

基于slim,由于是一个比较简单的分类问题,网络结构也很简单,几个卷积加池化。

测试效果是很棒的。真实样本测试集能达到99%+的准确率。

2.模型固化,生成pb文件

#coding:utf-8
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
import cv2
import os
import numpy as np
from datasets import dataset_factory
from preprocessing import preprocessing_factory
from tensorflow.python.platform import gfile
slim = tf.contrib.slim
#todo
#support arbitray image size and num_class
 
tf.app.flags.DEFINE_string(
 'checkpoint_path', '/tmp/tfmodel/',
 'The directory where the model was written to or an absolute path to a '
 'checkpoint file.')
 
tf.app.flags.DEFINE_string(
 'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
 'preprocessing_name', None, 'The name of the preprocessing to use. If left '
 'as `None`, then the model_name flag is used.')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer(
 'eval_image_size', None, 'Eval image size')
tf.app.flags.DEFINE_integer(
 'eval_image_height', None, 'Eval image height')
tf.app.flags.DEFINE_integer(
 'eval_image_width', None, 'Eval image width')
tf.app.flags.DEFINE_string(
 'export_path', './ttnet_1.0_37_32.pb', 'the export path of the pd file')
FLAGS = tf.app.flags.FLAGS
NUM_CLASSES = 37
 
def main(_):
 network_fn = nets_factory.get_network_fn(
  FLAGS.model_name,
  num_classes=NUM_CLASSES,
  is_training=False)
 # pre_image = tf.placeholder(tf.float32, [None, None, 3], name='input_data')
 # preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
 # image_preprocessing_fn = preprocessing_factory.get_preprocessing(
 #  preprocessing_name,
 #  is_training=False)
 # image = image_preprocessing_fn(pre_image, FLAGS.eval_image_height, FLAGS.eval_image_width)
 # images2 = tf.expand_dims(image, 0)
 images2 = tf.placeholder(tf.float32, (None,32, 32, 3),name='input_data')
 logits, endpoints = network_fn(images2)
 with tf.Session() as sess:
 output = tf.identity(endpoints['Predictions'],name="output_data")
 with gfile.GFile(FLAGS.export_path, 'wb') as f:
  f.write(sess.graph_def.SerializeToString())
 
if __name__ == '__main__':
 tf.app.run()

3.生成tflite文件

import tensorflow as tf
 
graph_def_file = "/datastore1/Colonist_Lord/Colonist_Lord/workspace/models/model_files/passport_model_with_tflite/ocr_frozen.pb"
input_arrays = ["input_data"]
output_arrays = ["output_data"]
 
converter = tf.lite.TFLiteConverter.from_frozen_graph(
 graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

使用pb文件进行测试,效果正常;使用tflite文件进行测试,精度下降严重。下面附上pb与tflite测试代码。

pb测试代码

with tf.gfile.GFile(graph_filename, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 
with tf.Graph().as_default() as graph:
 tf.import_graph_def(graph_def)
 input_node = graph.get_tensor_by_name('import/input_data:0')
 output_node = graph.get_tensor_by_name('import/output_data:0')
 with tf.Session() as sess:
  for image_file in image_files:
   abs_path = os.path.join(image_folder, image_file)
   img = cv2.imread(abs_path).astype(np.float32)
   img = cv2.resize(img, (int(input_node.shape[1]), int(input_node.shape[2])))
   output_data = sess.run(output_node, feed_dict={input_node: [img]})
   index = np.argmax(output_data)
   label = dict_laebl[index]
   dst_floder = os.path.join(result_folder, label)
   if not os.path.exists(dst_floder):
    os.mkdir(dst_floder)
   cv2.imwrite(os.path.join(dst_floder, image_file), img)
   count += 1

tflite测试代码

model_path = "converted_model.tflite" #"/datastore1/Colonist_Lord/Colonist_Lord/data/passport_char/ocr.tflite"
interpreter = tf.contrib.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
 
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for image_file in image_files:
 abs_path = os.path.join(image_folder,image_file)
 img = cv2.imread(abs_path).astype(np.float32)
 img = cv2.resize(img, tuple(input_details[0]['shape'][1:3]))
 # input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
 interpreter.set_tensor(input_details[0]['index'], [img])
 
 interpreter.invoke()
 output_data = interpreter.get_tensor(output_details[0]['index'])
 index = np.argmax(output_data)
 label = dict_laebl[index]
 dst_floder = os.path.join(result_folder,label)
 if not os.path.exists(dst_floder):
  os.mkdir(dst_floder)
 cv2.imwrite(os.path.join(dst_floder,image_file),img)
 count+=1

最后也算是绕过这个问题解决了业务需求,后面有空的话,还是会花时间研究一下这个问题。

如果有哪个大佬知道原因,希望不吝赐教。

补充知识:.pb 转tflite代码,使用量化,减小体积,converter.post_training_quantize = True

import tensorflow as tf

path = "/home/python/Downloads/a.pb" # pb文件位置和文件名
inputs = ["input_images"] # 模型文件的输入节点名称
classes = ['feature_fusion/Conv_7/Sigmoid','feature_fusion/concat_3'] # 模型文件的输出节点名称
# converter = tf.contrib.lite.TocoConverter.from_frozen_graph(path, inputs, classes, input_shapes={'input_images':[1, 320, 320, 3]})
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, classes,
              input_shapes={'input_images': [1, 320, 320, 3]})
converter.post_training_quantize = True
tflite_model = converter.convert()
open("/home/python/Downloads/aNew.tflite", "wb").write(tflite_model)

以上这篇tensorflow pb to tflite 精度下降详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现简单文本字符串处理的方法
Jan 22 Python
Django框架使用富文本编辑器Uedit的方法分析
Jul 31 Python
Python Pywavelet 小波阈值实例
Jan 09 Python
详解Python正则表达式re模块
Mar 19 Python
NumPy 基本切片和索引的具体使用方法
Apr 24 Python
python 将字符串完成特定的向右移动方法
Jun 11 Python
python 反编译exe文件为py文件的实例代码
Jun 27 Python
深入浅析Python 函数注解与匿名函数
Feb 24 Python
解决pycharm每次打开项目都需要配置解释器和安装库问题
Feb 26 Python
在tensorflow下利用plt画论文中loss,acc等曲线图实例
Jun 15 Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 Python
python ConfigParser库的使用及遇到的坑
Feb 12 Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
You might like
php判断输入不超过mysql的varchar字段的长度范围
2011/06/24 PHP
php 目录遍历、删除 函数的使用介绍
2013/04/28 PHP
PHP通过插入mysql数据来实现多机互锁实例
2014/11/05 PHP
基于PHP实现邮箱验证激活过程详解
2020/10/28 PHP
尽可能写"友好"的"Javascript"代码
2007/01/09 Javascript
js left,right,mid函数
2008/06/10 Javascript
Jquery注册事件实现方法
2015/05/18 Javascript
jqueryMobile使用示例分享
2016/01/12 Javascript
浅谈js数组和splice的用法
2016/12/04 Javascript
Vue组件之全局组件与局部组件的使用详解
2017/10/09 Javascript
vue中动态绑定表单元素的属性方法
2018/02/23 Javascript
vue.js添加一些触摸事件以及安装fastclick的实例
2018/08/28 Javascript
解决vue 绑定对象内点击事件失效问题
2018/09/05 Javascript
vue-vuex中使用commit提交mutation来修改state的方法详解
2018/09/16 Javascript
详解Vue底部导航栏组件
2019/05/02 Javascript
24个解决实际问题的ES6代码片段(小结)
2020/02/02 Javascript
ES6中的类(Class)示例详解
2020/12/09 Javascript
使用python编写android截屏脚本双击运行即可
2014/07/21 Python
python中列表和元组的区别
2017/12/18 Python
windows下python安装pip图文教程
2018/05/25 Python
基于Python pip用国内镜像下载的方法
2018/06/12 Python
Django配置celery(非djcelery)执行异步任务和定时任务
2018/07/16 Python
python保存二维数组到txt文件中的方法
2018/11/15 Python
python中的函数递归和迭代原理解析
2019/11/14 Python
PyTorch实现AlexNet示例
2020/01/14 Python
windows下python安装pip方法详解
2020/02/10 Python
python 在右键菜单中加入复制目标文件的有效存放路径(单斜杠或者双反斜杠)
2020/04/08 Python
python将logging模块封装成单独模块并实现动态切换Level方式
2020/05/12 Python
浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点
2020/06/08 Python
python向xls写入数据(包括合并,边框,对齐,列宽)
2021/02/02 Python
HTML5新增加的功能详解
2016/09/05 HTML / CSS
五星红旗迎风飘扬观后感
2015/06/17 职场文书
经典祝酒词大全
2015/08/12 职场文书
晶体管单管来复再生式收音机
2021/04/22 无线电
解决python存数据库速度太慢的问题
2021/04/23 Python
MySQL优化及索引解析
2022/03/17 MySQL