tensorflow pb to tflite 精度下降详解


Posted in Python onMay 25, 2020

之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。

思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能支持业务需求了,最后就把OCR模型调用写在服务端了,但是精度下降的原因目前也没有找到,现在这里记录一下。

工作思路:

1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;

但是使用python 的tf interpreter 调用tflite文件就已经出现精度下降的问题,android端部署也是一样。

1.网络结构

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import tensorflow as tf
slim = tf.contrib.slim
 
def ttnet(images, num_classes=10, is_training=False,
   dropout_keep_prob=0.5,
   prediction_fn=slim.softmax,
   scope='TtNet'):
 end_points = {}
 
 with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
 net = slim.conv2d(images, 32, [3, 3], scope='conv1')
 # net = slim.conv2d(images, 64, [3, 3], scope='conv1_2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1')
 # net = slim.conv2d(net, 128, [3, 3], scope='conv2_1')
 net = slim.conv2d(net, 64, [3, 3], scope='conv2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
 net = slim.conv2d(net, 128, [3, 3], scope='conv3')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
 net = slim.conv2d(net, 256, [3, 3], scope='conv4')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2')
 # net = slim.conv2d(net, 512, [3, 3], scope='conv5')
 # net = slim.max_pool2d(net, [2, 2], 2, scope='pool5')
 net = slim.flatten(net)
 end_points['Flatten'] = net
 
 # net = slim.fully_connected(net, 1024, scope='fc3')
 net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
      scope='dropout3')
 logits = slim.fully_connected(net, num_classes, activation_fn=None,
         scope='fc4') 
 end_points['Logits'] = logits
 end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
 
 return logits, end_points
ttnet.default_image_size = 28
 
def ttnet_arg_scope(weight_decay=0.0):
 with slim.arg_scope(
  [slim.conv2d, slim.fully_connected],
  weights_regularizer=slim.l2_regularizer(weight_decay),
  weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
  activation_fn=tf.nn.relu) as sc:
 return sc

基于slim,由于是一个比较简单的分类问题,网络结构也很简单,几个卷积加池化。

测试效果是很棒的。真实样本测试集能达到99%+的准确率。

2.模型固化,生成pb文件

#coding:utf-8
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
import cv2
import os
import numpy as np
from datasets import dataset_factory
from preprocessing import preprocessing_factory
from tensorflow.python.platform import gfile
slim = tf.contrib.slim
#todo
#support arbitray image size and num_class
 
tf.app.flags.DEFINE_string(
 'checkpoint_path', '/tmp/tfmodel/',
 'The directory where the model was written to or an absolute path to a '
 'checkpoint file.')
 
tf.app.flags.DEFINE_string(
 'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
 'preprocessing_name', None, 'The name of the preprocessing to use. If left '
 'as `None`, then the model_name flag is used.')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer(
 'eval_image_size', None, 'Eval image size')
tf.app.flags.DEFINE_integer(
 'eval_image_height', None, 'Eval image height')
tf.app.flags.DEFINE_integer(
 'eval_image_width', None, 'Eval image width')
tf.app.flags.DEFINE_string(
 'export_path', './ttnet_1.0_37_32.pb', 'the export path of the pd file')
FLAGS = tf.app.flags.FLAGS
NUM_CLASSES = 37
 
def main(_):
 network_fn = nets_factory.get_network_fn(
  FLAGS.model_name,
  num_classes=NUM_CLASSES,
  is_training=False)
 # pre_image = tf.placeholder(tf.float32, [None, None, 3], name='input_data')
 # preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
 # image_preprocessing_fn = preprocessing_factory.get_preprocessing(
 #  preprocessing_name,
 #  is_training=False)
 # image = image_preprocessing_fn(pre_image, FLAGS.eval_image_height, FLAGS.eval_image_width)
 # images2 = tf.expand_dims(image, 0)
 images2 = tf.placeholder(tf.float32, (None,32, 32, 3),name='input_data')
 logits, endpoints = network_fn(images2)
 with tf.Session() as sess:
 output = tf.identity(endpoints['Predictions'],name="output_data")
 with gfile.GFile(FLAGS.export_path, 'wb') as f:
  f.write(sess.graph_def.SerializeToString())
 
if __name__ == '__main__':
 tf.app.run()

3.生成tflite文件

import tensorflow as tf
 
graph_def_file = "/datastore1/Colonist_Lord/Colonist_Lord/workspace/models/model_files/passport_model_with_tflite/ocr_frozen.pb"
input_arrays = ["input_data"]
output_arrays = ["output_data"]
 
converter = tf.lite.TFLiteConverter.from_frozen_graph(
 graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

使用pb文件进行测试,效果正常;使用tflite文件进行测试,精度下降严重。下面附上pb与tflite测试代码。

pb测试代码

with tf.gfile.GFile(graph_filename, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 
with tf.Graph().as_default() as graph:
 tf.import_graph_def(graph_def)
 input_node = graph.get_tensor_by_name('import/input_data:0')
 output_node = graph.get_tensor_by_name('import/output_data:0')
 with tf.Session() as sess:
  for image_file in image_files:
   abs_path = os.path.join(image_folder, image_file)
   img = cv2.imread(abs_path).astype(np.float32)
   img = cv2.resize(img, (int(input_node.shape[1]), int(input_node.shape[2])))
   output_data = sess.run(output_node, feed_dict={input_node: [img]})
   index = np.argmax(output_data)
   label = dict_laebl[index]
   dst_floder = os.path.join(result_folder, label)
   if not os.path.exists(dst_floder):
    os.mkdir(dst_floder)
   cv2.imwrite(os.path.join(dst_floder, image_file), img)
   count += 1

tflite测试代码

model_path = "converted_model.tflite" #"/datastore1/Colonist_Lord/Colonist_Lord/data/passport_char/ocr.tflite"
interpreter = tf.contrib.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
 
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for image_file in image_files:
 abs_path = os.path.join(image_folder,image_file)
 img = cv2.imread(abs_path).astype(np.float32)
 img = cv2.resize(img, tuple(input_details[0]['shape'][1:3]))
 # input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
 interpreter.set_tensor(input_details[0]['index'], [img])
 
 interpreter.invoke()
 output_data = interpreter.get_tensor(output_details[0]['index'])
 index = np.argmax(output_data)
 label = dict_laebl[index]
 dst_floder = os.path.join(result_folder,label)
 if not os.path.exists(dst_floder):
  os.mkdir(dst_floder)
 cv2.imwrite(os.path.join(dst_floder,image_file),img)
 count+=1

最后也算是绕过这个问题解决了业务需求,后面有空的话,还是会花时间研究一下这个问题。

如果有哪个大佬知道原因,希望不吝赐教。

补充知识:.pb 转tflite代码,使用量化,减小体积,converter.post_training_quantize = True

import tensorflow as tf

path = "/home/python/Downloads/a.pb" # pb文件位置和文件名
inputs = ["input_images"] # 模型文件的输入节点名称
classes = ['feature_fusion/Conv_7/Sigmoid','feature_fusion/concat_3'] # 模型文件的输出节点名称
# converter = tf.contrib.lite.TocoConverter.from_frozen_graph(path, inputs, classes, input_shapes={'input_images':[1, 320, 320, 3]})
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, classes,
              input_shapes={'input_images': [1, 320, 320, 3]})
converter.post_training_quantize = True
tflite_model = converter.convert()
open("/home/python/Downloads/aNew.tflite", "wb").write(tflite_model)

以上这篇tensorflow pb to tflite 精度下降详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django中几种重定向方法
Apr 28 Python
Pthon批量处理将pdb文件生成dssp文件
Jun 21 Python
Python标准库之collections包的使用教程
Apr 27 Python
Pandas 数据框增、删、改、查、去重、抽样基本操作方法
Apr 12 Python
解决python大批量读写.doc文件的问题
May 08 Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 Python
python图像处理模块Pillow的学习详解
Oct 09 Python
win7下 python3.6 安装opencv 和 opencv-contrib-python解决 cv2.xfeatures2d.SIFT_create() 的问题
Oct 24 Python
Python开发之pip安装及使用方法详解
Feb 21 Python
Python编程快速上手——PDF文件操作案例分析
Feb 28 Python
Python 通过监听端口实现唯一脚本运行方式
May 05 Python
Python图片检索之以图搜图
May 31 Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
You might like
smarty巧妙处理iframe中内容页的代码
2012/03/07 PHP
php 在windows下配置虚拟目录的方法介绍
2013/06/26 PHP
thinkphp验证码显示不出来的解决方法
2014/03/29 PHP
php防止伪造数据从地址栏URL提交的方法
2014/08/24 PHP
PHP中变量引用与变量销毁机制分析
2014/11/15 PHP
Zend Framework教程之Zend_Registry对象用法分析
2016/03/22 PHP
php面向对象的用户登录身份验证
2017/06/08 PHP
JavaScript 继承详解(二)
2009/07/13 Javascript
IE6下JS动态设置图片src地址问题
2010/01/08 Javascript
6款经典实用的jQuery小插件及源码(对话框/提示工具等等)
2013/02/04 Javascript
删除条目时弹出的确认对话框
2014/06/05 Javascript
javascript匿名函数实例分析
2014/11/18 Javascript
Javascript 拖拽雏形中的一些问题(逐行分析代码,让你轻松了拖拽的原理)
2015/01/23 Javascript
jQuery对象与DOM对象之间的相互转换
2015/03/03 Javascript
jquery append 动态添加的元素事件on 不起作用的解决方案
2015/07/30 Javascript
javascript适合移动端的日期时间拾取器
2015/11/10 Javascript
详解Bootstrap四种图片样式
2016/01/04 Javascript
jQuery选取所有复选框被选中的值并用Ajax异步提交数据的实例
2017/08/04 jQuery
Vue 列表上下过渡效果的实例代码
2019/06/25 Javascript
CKeditor4 字体颜色功能配置方法教程
2019/06/26 Javascript
[03:13]DOTA2-DPC中国联赛1月25日Recap集锦
2021/03/11 DOTA
python 实现一个贴吧图片爬虫的示例
2017/10/12 Python
Python自动化运维_文件内容差异对比分析
2017/12/13 Python
对python 匹配字符串开头和结尾的方法详解
2018/10/27 Python
python自动化生成IOS的图标
2018/11/13 Python
Python面向对象程序设计之类的定义与继承简单示例
2019/03/18 Python
详解pytorch中squeeze()和unsqueeze()函数介绍
2020/09/03 Python
解决python 在for循环并且pop数组的时候会跳过某些元素的问题
2020/12/11 Python
python爬虫如何解决图片验证码
2021/02/14 Python
详解解决jupyter不能使用pytorch的问题
2021/02/18 Python
财务会计岗位职责
2015/02/03 职场文书
3.15消费者权益日活动总结
2015/02/09 职场文书
高中教师个人工作总结
2015/02/10 职场文书
2015大学迎新标语
2015/07/16 职场文书
oracle通过存储过程上传list保存功能
2021/05/12 Oracle
解决sql server 数据库,sa用户被锁定的问题
2021/06/11 SQL Server