30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的Bottle框架中使用微信API的示例
Apr 23 Python
python文件与目录操作实例详解
Feb 22 Python
疯狂上涨的Python 开发者应从2.x还是3.x着手?
Nov 16 Python
Python爬虫包BeautifulSoup实例(三)
Jun 17 Python
Python生成器generator用法示例
Aug 10 Python
使用python打印十行杨辉三角过程详解
Jul 10 Python
sklearn+python:线性回归案例
Feb 24 Python
Python函数基本使用原理详解
Mar 19 Python
使用Pycharm在运行过程中,查看每个变量的操作(show variables)
Jun 08 Python
django haystack实现全文检索的示例代码
Jun 24 Python
pycharm debug 断点调试心得分享
Apr 16 Python
Golang Web 框架Iris安装部署
Aug 14 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
php 删除cookie和浏览器重定向
2009/03/16 PHP
PHP中array_map与array_column之间的关系分析
2014/08/19 PHP
Zend Framework+smarty用法实例详解
2016/03/19 PHP
php 读取文件夹下所有图片、文件的实例
2018/10/17 PHP
php输出文字乱码的解决方法
2019/10/04 PHP
jQuery Pagination Ajax分页插件(分页切换时无刷新与延迟)中文翻译版
2013/01/11 Javascript
jQuery:delegate中select()不起作用的解决方法(实例讲解)
2014/01/26 Javascript
jquery控制display属性为none或block
2014/03/31 Javascript
TypeScript 中接口详解
2015/06/19 Javascript
使用ngView配合AngularJS应用实现动画效果的方法
2015/06/19 Javascript
Angular页面间切换及传值的4种方法
2016/11/04 Javascript
提高JavaScript执行效率的23个实用技巧
2017/03/01 Javascript
JavaScript箭头(arrow)函数详解
2017/06/04 Javascript
Vue学习笔记进阶篇之vue-cli安装及介绍
2017/07/18 Javascript
详解如何在angular2中获取节点
2017/11/23 Javascript
vue登录路由验证的实现
2017/12/13 Javascript
JavaScript实现读取与输出XML文件数据的方法示例
2018/06/05 Javascript
JS将网址url转化为JSON格式的方法
2018/07/02 Javascript
详解angular应用容器化部署
2018/08/14 Javascript
js针对图片加载失败的处理方法分析
2019/08/24 Javascript
Vue 实现监听窗口关闭事件,并在窗口关闭前发送请求
2020/09/01 Javascript
在Python中marshal对象序列化的相关知识
2015/07/01 Python
Python 3中print函数的使用方法总结
2017/08/08 Python
Pandas 同元素多列去重的实例
2018/07/03 Python
树莓派使用USB摄像头和motion实现监控
2019/06/22 Python
django云端留言板实例详解
2019/07/22 Python
关于win10在tensorflow的安装及在pycharm中运行步骤详解
2020/03/16 Python
Python中Selenium模块的使用详解
2020/10/09 Python
python 实现数据库中数据添加、查询与更新的示例代码
2020/12/07 Python
学校创先争优活动总结
2014/08/28 职场文书
教师学习八项规定六项禁令思想汇报
2014/09/27 职场文书
政风行风建设整改方案
2014/10/27 职场文书
毕业论文致谢词
2015/05/14 职场文书
国庆节主题班会
2015/08/15 职场文书
python如何获取网络数据
2021/04/11 Python
CocosCreator ScrollView优化系列之分帧加载
2021/04/14 Python