详解如何用TensorFlow训练和识别/分类自定义图片


Posted in Python onAugust 05, 2019

很多正在入门或刚入门TensorFlow机器学习的同学希望能够通过自己指定图片源对模型进行训练,然后识别和分类自己指定的图片。但是,在TensorFlow官方入门教程中,并无明确给出如何把自定义数据输入训练模型的方法。现在,我们就参考官方入门课程《Deep MNIST for Experts》一节的内容(传送门:https://www.tensorflow.org/get_started/mnist/pros),介绍如何将自定义图片输入到TensorFlow的训练模型。

在《Deep MNISTfor Experts》一节的代码中,程序将TensorFlow自带的mnist图片数据集mnist.train.images作为训练输入,将mnist.test.images作为验证输入。当学习了该节内容后,我们会惊叹卷积神经网络的超高识别率,但对于刚开始学习TensorFlow的同学,内心可能会产生一个问号:如何将mnist数据集替换为自己指定的图片源?譬如,我要将图片源改为自己C盘里面的图片,应该怎么调整代码?

我们先看下该节课程中涉及到mnist图片调用的代码:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch = mnist.train.next_batch(50)
train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

对于刚接触TensorFlow的同学,要修改上述代码,可能会较为吃力。我也是经过一番摸索,才成功调用自己的图片集。

要实现输入自定义图片,需要自己先准备好一套图片集。为节省时间,我们把mnist的手写体数字集一张一张地解析出来,存放到自己的本地硬盘,保存为bmp格式,然后再把本地硬盘的手写体图片一张一张地读取出来,组成集合,再输入神经网络。mnist手写体数字集的提取方式详见《如何从TensorFlow的mnist数据集导出手写体数字图片》。

将mnist手写体数字集导出图片到本地后,就可以仿照以下python代码,实现自定义图片的训练:

#!/usr/bin/python3.5
# -*- coding: utf-8 -*- 
 
import os
 
import numpy as np
import tensorflow as tf
 
from PIL import Image
 
 
# 第一次遍历图片目录是为了获取图片总数
input_count = 0
for i in range(0,10):
  dir = './custom_images/%s/' % i         # 这里可以改成你自己的图片目录,i为分类标签
  for rt, dirs, files in os.walk(dir):
    for filename in files:
      input_count += 1
 
# 定义对应维数和各维长度的数组
input_images = np.array([[0]*784 for i in range(input_count)])
input_labels = np.array([[0]*10 for i in range(input_count)])
 
# 第二次遍历图片目录是为了生成图片数据和标签
index = 0
for i in range(0,10):
  dir = './custom_images/%s/' % i         # 这里可以改成你自己的图片目录,i为分类标签
  for rt, dirs, files in os.walk(dir):
    for filename in files:
      filename = dir + filename
      img = Image.open(filename)
      width = img.size[0]
      height = img.size[1]
      for h in range(0, height):
        for w in range(0, width):
          # 通过这样的处理,使数字的线条变细,有利于提高识别准确率
          if img.getpixel((w, h)) > 230:
            input_images[index][w+h*width] = 0
          else:
            input_images[index][w+h*width] = 1
      input_labels[index][i] = 1
      index += 1
 
 
# 定义输入节点,对应于图片像素值矩阵集合和图片标签(即所代表的数字)
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
 
x_image = tf.reshape(x, [-1, 28, 28, 1])
 
# 定义第一个卷积层的variables和ops
W_conv1 = tf.Variable(tf.truncated_normal([7, 7, 1, 32], stddev=0.1))
b_conv1 = tf.Variable(tf.constant(0.1, shape=[32]))
 
L1_conv = tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
L1_relu = tf.nn.relu(L1_conv + b_conv1)
L1_pool = tf.nn.max_pool(L1_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# 定义第二个卷积层的variables和ops
W_conv2 = tf.Variable(tf.truncated_normal([3, 3, 32, 64], stddev=0.1))
b_conv2 = tf.Variable(tf.constant(0.1, shape=[64]))
 
L2_conv = tf.nn.conv2d(L1_pool, W_conv2, strides=[1, 1, 1, 1], padding='SAME')
L2_relu = tf.nn.relu(L2_conv + b_conv2)
L2_pool = tf.nn.max_pool(L2_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
 
# 全连接层
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
 
h_pool2_flat = tf.reshape(L2_pool, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
 
# dropout
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
 
 
# readout层
W_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
 
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
 
# 定义优化器和训练op
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer((1e-4)).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
 
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
 
  print ("一共读取了 %s 个输入图像, %s 个标签" % (input_count, input_count))
 
  # 设置每次训练op的输入个数和迭代次数,这里为了支持任意图片总数,定义了一个余数remainder,譬如,如果每次训练op的输入个数为60,图片总数为150张,则前面两次各输入60张,最后一次输入30张(余数30)
  batch_size = 60
  iterations = 100
  batches_count = int(input_count / batch_size)
  remainder = input_count % batch_size
  print ("数据集分成 %s 批, 前面每批 %s 个数据,最后一批 %s 个数据" % (batches_count+1, batch_size, remainder))
 
  # 执行训练迭代
  for it in range(iterations):
    # 这里的关键是要把输入数组转为np.array
    for n in range(batches_count):
      train_step.run(feed_dict={x: input_images[n*batch_size:(n+1)*batch_size], y_: input_labels[n*batch_size:(n+1)*batch_size], keep_prob: 0.5})
    if remainder > 0:
      start_index = batches_count * batch_size;
      train_step.run(feed_dict={x: input_images[start_index:input_count-1], y_: input_labels[start_index:input_count-1], keep_prob: 0.5})
 
    # 每完成五次迭代,判断准确度是否已达到100%,达到则退出迭代循环
    iterate_accuracy = 0
    if it%5 == 0:
      iterate_accuracy = accuracy.eval(feed_dict={x: input_images, y_: input_labels, keep_prob: 1.0})
      print ('iteration %d: accuracy %s' % (it, iterate_accuracy))
      if iterate_accuracy >= 1:
        break;
 
  print ('完成训练!')

上述python代码的执行结果截图如下:

详解如何用TensorFlow训练和识别/分类自定义图片

对于上述代码中与模型构建相关的代码,请查阅官方《Deep MNIST for Experts》一节的内容进行理解。在本文中,需要重点掌握的是如何将本地图片源整合成为feed_dict可接受的格式。其中最关键的是这两行:

# 定义对应维数和各维长度的数组
input_images = np.array([[0]*784 for i in range(input_count)])
input_labels = np.array([[0]*10 for i in range(input_count)])

它们对应于feed_dict的两个placeholder:

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

这样一看,是不是很简单?

我们将在下一篇博文中介绍如何通过本文成果识别车牌数字。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Linux系统上安装Python的Scrapy框架的教程
Jun 11 Python
python虚拟环境virtualenv的使用教程
Oct 20 Python
Python探索之pLSA实现代码
Oct 25 Python
Python字典数据对象拆分的简单实现方法
Dec 05 Python
Python中列表与元组的乘法操作示例
Feb 10 Python
对python字典元素的添加与修改方法详解
Jul 06 Python
浅析python中numpy包中的argsort函数的使用
Aug 30 Python
python获取微信企业号打卡数据并生成windows计划任务
Apr 30 Python
Python自动化之数据驱动让你的脚本简洁10倍【推荐】
Jun 04 Python
python按比例随机切分数据的实现
Jul 11 Python
Django项目创建到启动详解(最全最详细)
Sep 07 Python
Python人工智能之混合高斯模型运动目标检测详解分析
Nov 07 Python
详解如何从TensorFlow的mnist数据集导出手写体数字图片
Aug 05 #Python
Python获取时间范围内日期列表和周列表的函数
Aug 05 #Python
Django ORM 查询管理器源码解析
Aug 05 #Python
python实现车牌识别的示例代码
Aug 05 #Python
使用python实现滑动验证码功能
Aug 05 #Python
Django 源码WSGI剖析过程详解
Aug 05 #Python
Python使用itchat 功能分析微信好友性别和位置
Aug 05 #Python
You might like
20个PHP常用类库小结
2011/09/11 PHP
ThinkPHP框架实现数据增删改
2017/05/07 PHP
PHP多种序列化/反序列化的方法详解
2017/06/23 PHP
tp5(thinkPHP5)操作mongoDB数据库的方法
2018/01/20 PHP
解决laravel groupBy 对查询结果进行分组出现的问题
2019/10/09 PHP
php 使用 __call实现重载功能示例
2019/11/18 PHP
处理及遍历XML文档DOM元素属性及方法整理
2013/08/23 Javascript
JavaScript两种跨域技术全面介绍
2014/04/16 Javascript
在css加载完毕后自动判断页面是否加入css或js文件
2014/09/10 Javascript
javascript实现的图片切割多块效果实例
2015/05/07 Javascript
利用css+原生js制作简单的钟表
2020/04/07 Javascript
jQuery Ajax请求后台数据并在前台接收
2016/12/10 Javascript
JavaScript中 DOM操作方法小结
2017/04/25 Javascript
JavaScript实现开关等效果
2017/09/08 Javascript
Bootstrap一款超好用的前端框架
2017/09/25 Javascript
微信小程序登录态和检验注册过没的app.js写法
2019/05/22 Javascript
Vue中通过属性绑定为元素绑定style行内样式的实例代码
2020/04/30 Javascript
js+canvas实现转盘效果(两个版本)
2020/09/13 Javascript
Python中使用 Selenium 实现网页截图实例
2014/07/18 Python
Python的Django框架可适配的各种数据库介绍
2015/07/15 Python
Python实现二维数组按照某行或列排序的方法【numpy lexsort】
2017/09/22 Python
解决Python requests库编码 socks5代理的问题
2018/05/07 Python
Python实现的json文件读取及中文乱码显示问题解决方法
2018/08/06 Python
python中join()方法介绍
2018/10/11 Python
python 获取一个值在某个区间的指定倍数的值方法
2018/11/12 Python
python plt可视化——打印特殊符号和制作图例代码
2020/04/17 Python
Python分类测试代码实例汇总
2020/07/23 Python
纯CSS实现的大小渐变、渐远效果
2014/04/15 HTML / CSS
AmazeUI 导航条的实现示例
2020/08/14 HTML / CSS
加拿大著名时装品牌:SOIA & KYO
2016/08/23 全球购物
奥地利票务门户网站:oeticket.com
2019/12/31 全球购物
国际贸易毕业生求职信范文
2014/02/21 职场文书
祖国在我心中演讲稿400字
2014/05/04 职场文书
校庆标语集锦
2014/06/25 职场文书
营销计划书范文
2015/01/17 职场文书
MySQL 语句执行顺序举例解析
2022/06/05 MySQL