30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读写文件操作示例程序
Dec 02 Python
python实现简单温度转换的方法
Mar 13 Python
python数据类型判断type与isinstance的区别实例解析
Oct 31 Python
python 3.0 模拟用户登录功能并实现三次错误锁定
Nov 01 Python
django的settings中设置中文支持的实现
Apr 28 Python
Python编程实现tail-n查看日志文件的方法
Jul 08 Python
python3模拟实现xshell远程执行liunx命令的方法
Jul 12 Python
python 怎样将dataframe中的字符串日期转化为日期的方法
Sep 26 Python
python如何从键盘获取输入实例
Jun 18 Python
Python爬虫之Selenium鼠标事件的实现
Dec 04 Python
Python列表的索引与切片
Apr 07 Python
Python万能模板案例之matplotlib绘制直方图的基本配置
Apr 13 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
php数组使用规则分析
2015/02/27 PHP
学习php设计模式 php实现装饰器模式(decorator)
2015/12/07 PHP
smarty高级特性之对象的使用方法
2015/12/25 PHP
PHP简单处理表单输入的特殊字符的方法
2016/02/03 PHP
Laravel框架实现的记录SQL日志功能示例
2018/06/19 PHP
php高性能日志系统 seaslog 的安装与使用方法分析
2020/02/29 PHP
javascript window.opener的用法分析
2010/04/07 Javascript
js 处理URL实用技巧
2010/11/23 Javascript
扩展jquery实现客户端表格的分页、排序功能代码
2011/03/16 Javascript
Extjs EditorGridPanel中ComboBox列的显示问题
2011/07/04 Javascript
Extjs中使用extend(js继承) 的代码
2012/03/15 Javascript
js弹出模式对话框,并接收回传值的方法
2013/03/12 Javascript
ExtJS4如何自动生成控制grid的列显示、隐藏的checkbox
2014/05/02 Javascript
jQuery 处理页面的事件详解
2015/01/20 Javascript
jQuery如何使用自动触发事件trigger
2015/11/29 Javascript
Node.js 文件夹目录结构创建实例代码
2016/07/08 Javascript
JS实现太极旋转思路分析
2016/12/09 Javascript
Angular1.x复杂指令实例详解
2017/03/01 Javascript
JavaScript 函数的定义-调用、注意事项
2017/04/16 Javascript
js图片轮播插件的封装
2017/07/21 Javascript
基于daterangepicker日历插件使用参数注意的问题
2017/08/10 Javascript
解决layui前端框架 form表单,table表等内置控件不显示的问题
2018/08/19 Javascript
vue axios 简单封装以及思考
2018/10/09 Javascript
JavaScript实现单图片上传并预览功能
2019/09/30 Javascript
JavaScript中的this/call/apply/bind的使用及区别
2020/03/06 Javascript
python的类变量和成员变量用法实例教程
2014/08/25 Python
python读取word文档的方法
2015/05/09 Python
Python中的列表生成式与生成器学习教程
2016/03/13 Python
运行Python编写的程序方法实例
2020/10/21 Python
Python Selenium XPath根据文本内容查找元素的方法
2020/12/07 Python
python爬虫爬取某网站视频的示例代码
2021/02/20 Python
CSS3 @media的基本用法总结
2019/09/10 HTML / CSS
广州御银科技股份有限公司试卷(C++)
2016/11/04 面试题
房屋租赁意向书
2014/04/01 职场文书
写给医院的感谢信
2015/01/22 职场文书
简历中的自我评价怎么写呢?
2019/04/30 职场文书