详解如何用TensorFlow训练和识别/分类自定义图片


Posted in Python onAugust 05, 2019

很多正在入门或刚入门TensorFlow机器学习的同学希望能够通过自己指定图片源对模型进行训练,然后识别和分类自己指定的图片。但是,在TensorFlow官方入门教程中,并无明确给出如何把自定义数据输入训练模型的方法。现在,我们就参考官方入门课程《Deep MNIST for Experts》一节的内容(传送门:https://www.tensorflow.org/get_started/mnist/pros),介绍如何将自定义图片输入到TensorFlow的训练模型。

在《Deep MNISTfor Experts》一节的代码中,程序将TensorFlow自带的mnist图片数据集mnist.train.images作为训练输入,将mnist.test.images作为验证输入。当学习了该节内容后,我们会惊叹卷积神经网络的超高识别率,但对于刚开始学习TensorFlow的同学,内心可能会产生一个问号:如何将mnist数据集替换为自己指定的图片源?譬如,我要将图片源改为自己C盘里面的图片,应该怎么调整代码?

我们先看下该节课程中涉及到mnist图片调用的代码:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch = mnist.train.next_batch(50)
train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

对于刚接触TensorFlow的同学,要修改上述代码,可能会较为吃力。我也是经过一番摸索,才成功调用自己的图片集。

要实现输入自定义图片,需要自己先准备好一套图片集。为节省时间,我们把mnist的手写体数字集一张一张地解析出来,存放到自己的本地硬盘,保存为bmp格式,然后再把本地硬盘的手写体图片一张一张地读取出来,组成集合,再输入神经网络。mnist手写体数字集的提取方式详见《如何从TensorFlow的mnist数据集导出手写体数字图片》。

将mnist手写体数字集导出图片到本地后,就可以仿照以下python代码,实现自定义图片的训练:

#!/usr/bin/python3.5
# -*- coding: utf-8 -*- 
 
import os
 
import numpy as np
import tensorflow as tf
 
from PIL import Image
 
 
# 第一次遍历图片目录是为了获取图片总数
input_count = 0
for i in range(0,10):
  dir = './custom_images/%s/' % i         # 这里可以改成你自己的图片目录,i为分类标签
  for rt, dirs, files in os.walk(dir):
    for filename in files:
      input_count += 1
 
# 定义对应维数和各维长度的数组
input_images = np.array([[0]*784 for i in range(input_count)])
input_labels = np.array([[0]*10 for i in range(input_count)])
 
# 第二次遍历图片目录是为了生成图片数据和标签
index = 0
for i in range(0,10):
  dir = './custom_images/%s/' % i         # 这里可以改成你自己的图片目录,i为分类标签
  for rt, dirs, files in os.walk(dir):
    for filename in files:
      filename = dir + filename
      img = Image.open(filename)
      width = img.size[0]
      height = img.size[1]
      for h in range(0, height):
        for w in range(0, width):
          # 通过这样的处理,使数字的线条变细,有利于提高识别准确率
          if img.getpixel((w, h)) > 230:
            input_images[index][w+h*width] = 0
          else:
            input_images[index][w+h*width] = 1
      input_labels[index][i] = 1
      index += 1
 
 
# 定义输入节点,对应于图片像素值矩阵集合和图片标签(即所代表的数字)
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
 
x_image = tf.reshape(x, [-1, 28, 28, 1])
 
# 定义第一个卷积层的variables和ops
W_conv1 = tf.Variable(tf.truncated_normal([7, 7, 1, 32], stddev=0.1))
b_conv1 = tf.Variable(tf.constant(0.1, shape=[32]))
 
L1_conv = tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
L1_relu = tf.nn.relu(L1_conv + b_conv1)
L1_pool = tf.nn.max_pool(L1_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
# 定义第二个卷积层的variables和ops
W_conv2 = tf.Variable(tf.truncated_normal([3, 3, 32, 64], stddev=0.1))
b_conv2 = tf.Variable(tf.constant(0.1, shape=[64]))
 
L2_conv = tf.nn.conv2d(L1_pool, W_conv2, strides=[1, 1, 1, 1], padding='SAME')
L2_relu = tf.nn.relu(L2_conv + b_conv2)
L2_pool = tf.nn.max_pool(L2_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 
 
# 全连接层
W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
 
h_pool2_flat = tf.reshape(L2_pool, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
 
 
# dropout
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
 
 
# readout层
W_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
 
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
 
# 定义优化器和训练op
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer((1e-4)).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
 
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
 
  print ("一共读取了 %s 个输入图像, %s 个标签" % (input_count, input_count))
 
  # 设置每次训练op的输入个数和迭代次数,这里为了支持任意图片总数,定义了一个余数remainder,譬如,如果每次训练op的输入个数为60,图片总数为150张,则前面两次各输入60张,最后一次输入30张(余数30)
  batch_size = 60
  iterations = 100
  batches_count = int(input_count / batch_size)
  remainder = input_count % batch_size
  print ("数据集分成 %s 批, 前面每批 %s 个数据,最后一批 %s 个数据" % (batches_count+1, batch_size, remainder))
 
  # 执行训练迭代
  for it in range(iterations):
    # 这里的关键是要把输入数组转为np.array
    for n in range(batches_count):
      train_step.run(feed_dict={x: input_images[n*batch_size:(n+1)*batch_size], y_: input_labels[n*batch_size:(n+1)*batch_size], keep_prob: 0.5})
    if remainder > 0:
      start_index = batches_count * batch_size;
      train_step.run(feed_dict={x: input_images[start_index:input_count-1], y_: input_labels[start_index:input_count-1], keep_prob: 0.5})
 
    # 每完成五次迭代,判断准确度是否已达到100%,达到则退出迭代循环
    iterate_accuracy = 0
    if it%5 == 0:
      iterate_accuracy = accuracy.eval(feed_dict={x: input_images, y_: input_labels, keep_prob: 1.0})
      print ('iteration %d: accuracy %s' % (it, iterate_accuracy))
      if iterate_accuracy >= 1:
        break;
 
  print ('完成训练!')

上述python代码的执行结果截图如下:

详解如何用TensorFlow训练和识别/分类自定义图片

对于上述代码中与模型构建相关的代码,请查阅官方《Deep MNIST for Experts》一节的内容进行理解。在本文中,需要重点掌握的是如何将本地图片源整合成为feed_dict可接受的格式。其中最关键的是这两行:

# 定义对应维数和各维长度的数组
input_images = np.array([[0]*784 for i in range(input_count)])
input_labels = np.array([[0]*10 for i in range(input_count)])

它们对应于feed_dict的两个placeholder:

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

这样一看,是不是很简单?

我们将在下一篇博文中介绍如何通过本文成果识别车牌数字。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
星球大战与Python之间的那些事
Jan 07 Python
python处理Excel xlrd的简单使用
Sep 12 Python
Python正则表达式知识汇总
Sep 22 Python
python实现简易通讯录修改版
Mar 13 Python
python爬虫获取小区经纬度以及结构化地址
Dec 30 Python
python快速编写单行注释多行注释的方法
Jul 31 Python
Django中自定义admin Xadmin的实现代码
Aug 09 Python
python:目标检测模型预测准确度计算方式(基于IoU)
Jan 18 Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 Python
python实现同一局域网下传输图片
Mar 20 Python
Python使用pdb调试代码的技巧
May 03 Python
Anaconda的安装与虚拟环境建立
Nov 18 Python
详解如何从TensorFlow的mnist数据集导出手写体数字图片
Aug 05 #Python
Python获取时间范围内日期列表和周列表的函数
Aug 05 #Python
Django ORM 查询管理器源码解析
Aug 05 #Python
python实现车牌识别的示例代码
Aug 05 #Python
使用python实现滑动验证码功能
Aug 05 #Python
Django 源码WSGI剖析过程详解
Aug 05 #Python
Python使用itchat 功能分析微信好友性别和位置
Aug 05 #Python
You might like
PHP 向右侧拉菜单实现代码,测试使用中
2009/11/03 PHP
PHP+MYSQL实现用户的增删改查
2015/03/24 PHP
百度地图经纬度转换到腾讯地图/Google 对应的经纬度
2015/08/28 PHP
PHP7之Mongodb API使用详解
2015/12/26 PHP
PHP实现RTX发送消息提醒的实例代码
2017/01/03 PHP
PHP从零开始打造自己的MVC框架之入口文件实现方法详解
2019/06/03 PHP
encode脚本和normal脚本混用的问题与解决方法
2007/03/08 Javascript
javascript 24小时弹出一次的代码(利用cookies)
2009/09/03 Javascript
Javascript UrlDecode函数代码
2010/01/09 Javascript
javascript中的float运算精度实例分析
2010/08/21 Javascript
使用控制台破解百小度一个月只准改一次名字
2015/08/13 Javascript
基于jQuery Bar Indicator 插件实现进度条展示效果
2015/09/30 Javascript
jquery关于事件冒泡和事件委托的技巧及阻止与允许事件冒泡的三种实现方法
2015/11/27 Javascript
基于javascript实现listbox左右移动
2016/01/29 Javascript
JavaScript中的各种操作符使用总结
2016/05/26 Javascript
基于JavaScript实现树形下拉框
2016/08/10 Javascript
vue-router单页面路由
2017/06/17 Javascript
微信小程序入门之广告条实现方法示例
2018/12/05 Javascript
laypage.js分页插件使用方法详解
2019/07/27 Javascript
[02:08]我的刀塔不可能这么可爱 胡晓桃_1
2014/06/20 DOTA
python类:class创建、数据方法属性及访问控制详解
2016/07/25 Python
python实现文件批量编码转换及注意事项
2019/10/14 Python
基于python实现图片转字符画代码实例
2020/09/04 Python
今天学到的CSS最新技术(与图片背景相关)
2012/12/24 HTML / CSS
多重CSS背景动画实现方法示例
2014/04/04 HTML / CSS
html5服务器推送_动力节点Java学院整理
2017/07/12 HTML / CSS
Kipling凯浦林美国官网:世界著名时尚休闲包袋品牌
2016/08/24 全球购物
选购世界上最好的美妆品:Cult Beauty
2017/11/03 全球购物
摩飞电器俄罗斯官方网站:Morphy Richards俄罗斯
2020/07/30 全球购物
Jones Bootmaker官网:优质靴子和鞋子在线
2020/11/30 全球购物
三年级数学教学反思
2014/01/31 职场文书
分层教学实施方案
2014/03/19 职场文书
教师爱岗敬业演讲稿
2014/05/05 职场文书
工作汇报开头与结尾怎么写
2014/11/08 职场文书
公安机关起诉意见书
2015/05/20 职场文书
python数字图像处理之图像自动阈值分割示例
2022/06/28 Python