30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的ORM框架SQLObject入门实例
Apr 28 Python
在Windows系统上搭建Nginx+Python+MySQL环境的教程
Dec 25 Python
举例讲解Python中的迭代器、生成器与列表解析用法
Mar 20 Python
python 生成器生成杨辉三角的方法(必看)
Apr 10 Python
简单学习Python多进程Multiprocessing
Aug 29 Python
浅谈使用Python变量时要避免的3个错误
Oct 30 Python
Python+tkinter模拟“记住我”自动登录实例代码
Jan 16 Python
python定时检测无响应进程并重启的实例代码
Apr 22 Python
Python3 执行Linux Bash命令的方法
Jul 12 Python
Java ExcutorService优雅关闭方式解析
May 30 Python
Python判断远程服务器上Excel文件是否被人打开的方法
Jul 13 Python
手把手教你如何用Pycharm2020.1.1配置远程连接的详细步骤
Aug 07 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
PHP提取字符串中的图片地址[正则表达式]
2011/11/12 PHP
PHP+MYSQL实现用户的增删改查
2015/03/24 PHP
Cookie跨域问题解决方案代码示例
2020/11/24 PHP
破解Session cookie的方法
2006/07/28 Javascript
鼠标移入移出事件改变图片的分辨率的两种方法
2013/12/17 Javascript
JQuery结合CSS操作打印样式的方法
2013/12/24 Javascript
常用的几段javascript代码分享
2014/03/25 Javascript
jQuery的$.proxy()应用示例介绍
2014/04/03 Javascript
AngularJS入门教程之学习环境搭建
2014/12/06 Javascript
简单实现异步编程promise模式
2015/07/31 Javascript
跟我学习javascript的for循环和for...in循环
2015/11/18 Javascript
location.hash保存页面状态的技巧
2016/04/28 Javascript
JS和jQuery使用submit方法无法提交表单的原因分析及解决办法
2016/05/17 Javascript
实例详解jQuery的无new构建
2016/08/02 Javascript
基于Bootstrap仿淘宝分页控件实现代码
2016/11/07 Javascript
微信小程序 引入es6 promise
2017/04/12 Javascript
微信小程序上滑加载下拉刷新(onscrollLower)分批加载数据(二)
2017/05/11 Javascript
js移动端事件基础及常用事件库详解
2017/08/15 Javascript
如何理解Vue的v-model指令的使用方法
2018/07/19 Javascript
解决vue项目nginx部署到非根目录下刷新空白的问题
2018/09/27 Javascript
Vue.js中provide/inject实现响应式数据更新的方法示例
2019/10/16 Javascript
python错误:AttributeError: 'module' object has no attribute 'setdefaultencoding'问题的解决方法
2014/08/22 Python
使用Python脚本操作MongoDB的教程
2015/04/16 Python
Python设计模式编程中Adapter适配器模式的使用实例
2016/03/02 Python
PyCharm 常用快捷键和设置方法
2017/12/20 Python
python3实现猜数字游戏
2020/12/07 Python
Python爬虫实现百度翻译功能过程详解
2020/05/29 Python
Python如何向SQLServer存储二进制图片
2020/06/08 Python
什么是java序列化,如何实现java序列化
2012/11/14 面试题
人资专员岗位职责
2014/04/04 职场文书
吨的认识教学反思
2014/04/27 职场文书
公司法定代表人授权委托书
2014/09/29 职场文书
党的作风建设心得体会
2014/10/22 职场文书
2016年公司新年寄语
2015/08/17 职场文书
单身狗福利?Python爬取某婚恋网征婚数据
2021/06/03 Python
JavaScript设计模式之原型模式详情
2022/06/21 Javascript