30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现查找webshell的方法
Jul 31 Python
Python实现list反转实例汇总
Nov 11 Python
python解决汉字编码问题:Unicode Decode Error
Jan 19 Python
Python学习小技巧之列表项的拼接
May 20 Python
python中defaultdict的用法详解
Jun 07 Python
Python算法之求n个节点不同二叉树个数
Oct 27 Python
Python实现一个Git日志统计分析的小工具
Dec 14 Python
Django中redis的使用方法(包括安装、配置、启动)
Feb 21 Python
python利用wx实现界面按钮和按钮监听和字体改变的方法
Jul 17 Python
python中文分词库jieba使用方法详解
Feb 11 Python
python利用递归方法实现求集合的幂集
Sep 07 Python
BeautifulSoup中find和find_all的使用详解
Dec 07 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
PHP is_subclass_of函数的一个BUG和解决方法
2014/06/01 PHP
php设计模式之委托模式
2016/02/13 PHP
原生js实现跨浏览器获取鼠标按键的值
2013/04/08 Javascript
jquery事件preventDefault()方法用法实例
2015/01/16 Javascript
javascript跨域方法、原理以及出现问题解决方法(详解)
2015/08/06 Javascript
新入门node.js必须要知道的概念(必看篇)
2016/08/10 Javascript
简单谈谈ES6的六个小特性
2016/11/18 Javascript
基于Bootstrap漂亮简洁的CSS3价格表(附源码下载)
2017/02/28 Javascript
vue2实现移动端上传、预览、压缩图片解决拍照旋转问题
2017/04/13 Javascript
详解使用vue实现tab 切换操作
2017/07/03 Javascript
利用百度地图API获取当前位置信息的实例
2017/11/06 Javascript
ES6/JavaScript使用技巧分享
2017/12/14 Javascript
JavaScript中click和onclick本质区别与用法分析
2018/06/07 Javascript
浅谈vuex为什么不建议在action中修改state
2020/02/02 Javascript
jquery实现抽奖功能
2020/10/22 jQuery
微信小程序中target和currentTarget的区别小结
2020/11/06 Javascript
Python实现文件复制删除
2016/04/19 Python
Python采用Django开发自己的博客系统
2020/09/29 Python
Python绘制3d螺旋曲线图实例代码
2017/12/20 Python
Django后台获取前端post上传的文件方法
2018/05/28 Python
python调用matlab的m自定义函数方法
2019/02/18 Python
Python集合操作方法详解
2020/02/09 Python
LocalStorage记住用户和密码功能
2017/07/24 HTML / CSS
SmartBuyGlasses比利时:购买品牌太阳镜和眼镜
2019/08/09 全球购物
精致的手工皮鞋:Shoe Embassy
2019/11/08 全球购物
C#中有没有运算符重载?能否使用指针?
2014/05/05 面试题
会计毕业生求职简历的自我评价
2013/10/20 职场文书
暑期社会实践学生的自我评价
2014/01/09 职场文书
道路建设实施方案
2014/03/18 职场文书
会议接待欢迎标语
2014/10/08 职场文书
绿色校园广播稿
2014/10/13 职场文书
大连星海广场导游词
2015/02/10 职场文书
2015年党务工作者个人工作总结
2015/10/22 职场文书
大学学生会主席竞选稿怎么写?
2019/08/19 职场文书
基于Python实现射击小游戏的制作
2022/04/06 Python
Flutter Navigator 实现路由传递参数
2022/04/22 Java/Android