30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的rfind()方法使用详解
May 19 Python
python使用urllib2提交http post请求的方法
May 26 Python
Python的条件语句与运算符优先级详解
Oct 13 Python
python中kmeans聚类实现代码
Feb 23 Python
Numpy数组的保存与读取方法
Apr 04 Python
Django 浅谈根据配置生成SQL语句的问题
May 29 Python
python使用matplotlib画柱状图、散点图
Mar 18 Python
python:解析requests返回的response(json格式)说明
Apr 30 Python
django 利用Q对象与F对象进行查询的实现
May 15 Python
python 发送邮件的四种方法汇总
Dec 02 Python
pycharm配置python 设置pip安装源为豆瓣源
Feb 05 Python
python和Appium的移动端多设备自动化测试框架
Apr 26 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
常用的php对象类型判断
2008/08/27 PHP
过滤掉PHP数组中的重复值的实现代码
2011/07/17 PHP
百度工程师讲PHP函数的实现原理及性能分析(一)
2015/05/13 PHP
PHP对文件夹递归执行chmod命令的方法
2015/06/19 PHP
谈谈PHP连接Access数据库的注意事项
2016/08/12 PHP
js文字滚动停顿效果代码
2008/06/28 Javascript
初学js 新节点的创建 删除 的步骤
2011/07/04 Javascript
Javascript获取窗口(容器)的大小及位置参数列举及简要说明
2012/12/09 Javascript
jquery点击页面任何区域实现鼠标焦点十字效果
2013/06/21 Javascript
eclipse如何忽略js文件报错(附图)
2013/10/30 Javascript
jQuery 鼠标经过(hover)事件的延时处理示例
2014/04/14 Javascript
jquery mobile页面跳转后样式丢失js失效的解决方法
2014/09/06 Javascript
jQuery中unwrap()方法用法实例
2015/01/16 Javascript
jquery获取及设置outerhtml的方法
2015/03/09 Javascript
基于jquery实现select选择框内容左右移动添加删除代码分享
2015/08/25 Javascript
JavaScript实现跑马灯抽奖活动实例代码解析与优化(一)
2016/02/16 Javascript
jquery uploadify如何取消已上传成功文件
2017/02/08 Javascript
vue移动端实现下拉刷新
2018/04/22 Javascript
Vue中的methods、watch、computed的区别
2018/11/26 Javascript
举例讲解Django中数据模型访问外键值的方法
2015/07/21 Python
使用Python从有道词典网页获取单词翻译
2016/07/03 Python
Python脚本实现监听服务器的思路代码详解
2020/05/28 Python
python datetime时间格式的相互转换问题
2020/06/11 Python
什么是Python包的循环导入
2020/09/08 Python
网购亚洲时装、美容产品和生活百货:YesStyle
2016/09/15 全球购物
Notino罗马尼亚网站:购买香水和化妆品
2019/07/20 全球购物
如何打开WebSphere远程debug
2014/10/10 面试题
.NET程序员的数据库面试题
2012/10/10 面试题
设计模式的基本要素是什么
2014/04/21 面试题
高二英语教学反思
2014/01/19 职场文书
小学生秋游活动方案
2014/02/23 职场文书
农业局党的群众路线教育实践活动整改方案
2014/09/20 职场文书
2014年底工作总结
2014/12/15 职场文书
一年级数学下册复习计划
2015/01/17 职场文书
星际争霸 Light vs Action 一场把教主看到鬼畜的比赛
2022/04/01 星际争霸
一文搞懂Redis中String数据类型
2022/04/03 Redis