30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python增量循环删除MySQL表数据的方法
Sep 23 Python
基于Python函数的作用域规则和闭包(详解)
Nov 29 Python
python flask中静态文件的管理方法
Mar 20 Python
Django在pycharm下修改默认启动端口的方法
Jul 26 Python
Python人工智能之路 之PyAudio 实现录音 自动化交互实现问答
Aug 13 Python
使用python获取邮箱邮件的设置方法
Sep 20 Python
python装饰器使用实例详解
Dec 14 Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 Python
Python爬虫库BeautifulSoup的介绍与简单使用实例
Jan 25 Python
聊聊Python pandas 中loc函数的使用,及跟iloc的区别说明
Mar 03 Python
python用字节处理文件实例讲解
Apr 13 Python
高考要来啦!用Python爬取历年高考数据并分析
Jun 03 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
PHP让网站移动访问更加友好方法
2019/02/14 PHP
微信公众平台开发教程③ PHP实现微信公众号支付功能图文详解
2019/04/10 PHP
thinkphp5.1框架容器与依赖注入实例分析
2019/07/23 PHP
php和nginx交互实例讲解
2019/09/24 PHP
jquery判断字符输入个数(数字英文长度记为1,中文记为2,超过长度自动截取)
2010/10/15 Javascript
用jquery中插件dialog实现弹框效果实例代码
2013/11/15 Javascript
javascript 中__proto__和prototype详解
2014/11/25 Javascript
基于bootstrap插件实现autocomplete自动完成表单
2016/05/07 Javascript
JavaScript中判断数据类型的方法总结
2016/05/24 Javascript
ionic实现带字的toggle滑动组件
2016/08/27 Javascript
详解使用grunt完成requirejs的合并压缩和js文件的版本控制
2017/03/02 Javascript
Node.js 使用命令行工具检查更新
2017/06/08 Javascript
Python读取Excel的方法实例分析
2015/07/11 Python
基于python实现微信模板消息
2015/12/21 Python
python实现爬虫统计学校BBS男女比例之多线程爬虫(二)
2015/12/31 Python
详解Python的collections模块中的deque双端队列结构
2016/07/07 Python
将string类型的数据类型转换为spark rdd时报错的解决方法
2019/02/18 Python
pytorch实现mnist分类的示例讲解
2020/01/10 Python
Tensorflow 实现释放内存
2020/02/03 Python
Windows下Anaconda和PyCharm的安装与使用详解
2020/04/23 Python
Python xmltodict模块安装及代码实例
2020/10/05 Python
css3.0 图形构成实例练习二
2013/03/19 HTML / CSS
船餐厅和泰晤士河餐饮游轮:Bateaux London
2018/03/19 全球购物
bonprix荷兰网上商店:便宜的服装、鞋子和家居用品
2020/07/04 全球购物
硕士研究生求职自荐信范文
2014/03/11 职场文书
聚美优品恶搞广告词
2014/03/14 职场文书
学校领导班子对照检查材料
2014/08/28 职场文书
学习张林森心得体会
2014/09/10 职场文书
公司租房协议书范本
2014/10/08 职场文书
开展党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
运动会闭幕式通讯稿
2015/07/18 职场文书
新教师2015年度工作总结
2015/07/22 职场文书
Nginx反爬虫策略,防止UA抓取网站
2021/03/31 Servers
python3 sqlite3限制条件查询的操作
2021/04/07 Python
Python控制台输出俄罗斯方块的方法实例
2021/04/17 Python
浅谈CSS不规则边框的生成方案
2021/05/25 HTML / CSS