30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中对list去重的多种方法
Sep 18 Python
python实现调用其他python脚本的方法
Oct 05 Python
Android 兼容性问题:java.lang.UnsupportedOperationException解决办法
Mar 19 Python
Python轻量级ORM框架Peewee访问sqlite数据库的方法详解
Jul 20 Python
Python多线程扫描端口代码示例
Feb 09 Python
Php多进程实现代码
May 07 Python
Python面向对象封装操作案例详解 II
Jan 02 Python
TensorFLow 不同大小图片的TFrecords存取实例
Jan 20 Python
Python 动态变量名定义与调用方法
Feb 09 Python
重写django的model下的objects模型管理器方式
May 15 Python
python PIL模块的基本使用
Sep 29 Python
virtualenv隔离Python环境的问题解析
Jun 21 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
PHP错误提示的关闭方法详解
2013/06/23 PHP
php中的四舍五入函数代码(floor函数、ceil函数、round与intval)
2014/07/14 PHP
利用PHP如何统计Nginx日志的User Agent数据
2019/03/06 PHP
通用于ie和firefox的函数 GetCurrentStyle (obj, prop)
2006/12/27 Javascript
JavaScript脚本性能的优化方法
2007/02/02 Javascript
JavaScript RegExp方法获取地址栏参数(面向对象)
2009/03/10 Javascript
jquery bind(click)传参让列表中每行绑定一个事件
2014/08/06 Javascript
jQuery动态添加可拖动元素完整实例(附demo源码下载)
2016/06/21 Javascript
基于bootstrap的文件上传控件bootstrap fileinput
2016/12/23 Javascript
JS仿QQ好友列表展开、收缩功能(第一篇)
2017/07/07 Javascript
Vux+Axios拦截器增加loading的问题及实现方法
2018/11/08 Javascript
JS获取表格视图所选行号的ids过程解析
2020/02/21 Javascript
Python下载指定页面上图片的方法
2016/05/12 Python
Python引用传值概念与用法实例小结
2017/10/07 Python
[原创]教女朋友学Python3(二)简单的输入输出及内置函数查看
2017/11/30 Python
微信跳一跳python辅助脚本(总结)
2018/01/11 Python
python 矩阵增加一行或一列的实例
2018/04/04 Python
在Python中实现替换字符串中的子串的示例
2018/10/31 Python
python实现简易数码时钟
2021/02/19 Python
浅谈Python接口对json串的处理方法
2018/12/19 Python
python3中rank函数的用法
2019/11/27 Python
HTML5 script元素async、defer异步加载使用介绍
2013/08/23 HTML / CSS
记一次高分屏下canvas模糊问题
2020/02/17 HTML / CSS
Hertz荷兰:荷兰和全球租车
2018/01/07 全球购物
英国在线房屋中介网站:Yopa
2018/01/09 全球购物
法律专业个人实习自我鉴定
2013/09/23 职场文书
校园联欢晚会主持词
2014/03/17 职场文书
党员干部群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
2015年元旦标语大全
2014/12/09 职场文书
优秀团员主要事迹材料
2015/11/05 职场文书
人生哲理妙语30条:淡写流年,笑过人生
2019/09/04 职场文书
Mysql - 常用函数 每天积极向上
2021/04/05 MySQL
纯html+css实现奥运五环的示例代码
2021/08/02 HTML / CSS
用Python生成会跳舞的美女
2022/01/18 Python
Python中的 Set 与 dict
2022/03/13 Python
Python中itertools库的四个函数介绍
2022/04/06 Python