30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用代理抓取网站图片(多线程)
Mar 14 Python
Python使用urllib模块的urlopen超时问题解决方法
Nov 08 Python
剖析Python的Tornado框架中session支持的实现代码
Aug 21 Python
Python的爬虫程序编写框架Scrapy入门学习教程
Jul 02 Python
Python标准库inspect的具体使用方法
Dec 06 Python
python+POP3实现批量下载邮件附件
Jun 19 Python
Python合并多个Excel数据的方法
Jul 16 Python
Python脚本利用adb进行手机控制的方法
Jul 08 Python
pytorch之Resize()函数具体使用详解
Feb 27 Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 Python
Matplotlib 绘制饼图解决文字重叠的方法
Jul 24 Python
python b站视频下载的五种版本
May 27 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
PHP防注入安全代码
2008/04/09 PHP
在命令行下运行PHP脚本[带参数]的方法
2010/01/22 PHP
php中取得文件的后缀名?
2012/02/20 PHP
php 修改、增加xml结点属性的实现代码
2013/10/22 PHP
PHP简单实现“相关文章推荐”功能的方法
2014/07/19 PHP
完美解决php 导出excle的.csv格式的数据时乱码问题
2017/02/18 PHP
PHP 进度条函数的简单实例
2017/09/19 PHP
php 截取中英文混合字符串的方法
2018/05/31 PHP
PHP生成随机字符串实例代码(字母+数字)
2019/09/11 PHP
基于jquery的has()方法以及与find()方法以及filter()方法的区别详解
2013/04/26 Javascript
图片Slider 带左右按钮的js示例
2013/08/30 Javascript
JS遍历页面所有对象属性及实现方法
2016/08/01 Javascript
利用jQuery插件imgAreaSelect实现获得选择域的图像信息
2016/12/02 Javascript
canvas实现绘制吃豆鱼效果
2017/01/12 Javascript
jQuery实现radio第一次点击选中第二次点击取消功能
2017/05/15 jQuery
jQuery实现键盘回车搜索功能
2017/07/25 jQuery
解决vue A对象赋值给B对象,修改B属性会影响到A的问题
2018/09/25 Javascript
微信小程序--特定区域滚动到顶部时固定的方法
2019/04/28 Javascript
微信小程序时间戳转日期的详解
2019/04/30 Javascript
详解VUE调用本地json的使用方法
2019/05/15 Javascript
原生js实现下拉选项卡
2019/11/27 Javascript
JavaScript中layim之整合右键菜单的示例代码
2021/02/06 Javascript
python单元测试unittest实例详解
2015/05/11 Python
浅谈Matplotlib简介和pyplot的简单使用——文本标注和箭头
2018/01/09 Python
Python实现可设置持续运行时间、线程数及时间间隔的多线程异步post请求功能
2018/01/11 Python
Python中一行和多行import模块问题
2018/04/01 Python
Python中的二维数组实例(list与numpy.array)
2018/04/13 Python
python判断数字是否是超级素数幂
2018/09/27 Python
python 接口实现 供第三方调用的例子
2019/08/13 Python
Pytorch .pth权重文件的使用解析
2020/02/14 Python
印度尼西亚在线时尚购物网站:ZALORA印尼
2016/08/02 全球购物
园林资料员岗位职责
2013/12/30 职场文书
团支书竞选演讲稿
2014/04/28 职场文书
家属答谢词
2015/01/05 职场文书
python 逐步回归算法
2021/04/06 Python
MySQL5.7并行复制原理及实现
2021/06/03 MySQL