tensorflow识别自己手写数字


Posted in Python onMarch 14, 2018

tensorflow作为google开源的项目,现在赶超了caffe,好像成为最受欢迎的深度学习框架。确实在编写的时候更能感受到代码的真实存在,这点和caffe不同,caffe通过编写配置文件进行网络的生成。环境tensorflow是0.10的版本,注意其他版本有的语句会有错误,这是tensorflow版本之间的兼容问题。

还需要安装PIL:pip install Pillow

图片的格式: 

? 图像标准化,可安装在20×20像素的框内,同时保留其长宽比。
? 图片都集中在一个28×28的图像中。
? 像素以列为主进行排序。像素值0到255,0表示背景(白色),255表示前景(黑色)。

创建一个.png的文件,背景是白色的,手写的字体是黑色的,

下面是数据测试的代码,一个两层的卷积神经网,然后用save进行模型的保存。

# coding: UTF-8 
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import input_data 
''''' 
得到数据 
''' 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
 
training = mnist.train.images 
trainlable = mnist.train.labels 
testing = mnist.test.images 
testlabel = mnist.test.labels 
 
print ("MNIST loaded") 
# 获取交互式的方式 
sess = tf.InteractiveSession() 
# 初始化变量 
x = tf.placeholder("float", shape=[None, 784]) 
y_ = tf.placeholder("float", shape=[None, 10]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
''''' 
生成权重函数,其中shape是数据的形状 
''' 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
''''' 
生成偏执项 其中shape是数据形状 
''' 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 
             strides=[1, 2, 2, 1], padding='SAME') 
 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
b_fc1 = bias_variable([1024]) 
 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
keep_prob = tf.placeholder("float") 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
 
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
 
# 保存网络训练的参数 
saver = tf.train.Saver() 
sess.run(tf.initialize_all_variables()) 
for i in range(8000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={ 
    x:batch[0], y_: batch[1], keep_prob: 1.0}) 
  print "step %d, training accuracy %g"%(i, train_accuracy) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 
 
save_path = saver.save(sess, "model_mnist.ckpt") 
print("Model saved in life:", save_path) 
 
print "test accuracy %g"%accuracy.eval(feed_dict={ 
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

其中input_data.py如下代码,是进行mnist数据集的下载的:代码是由mnist数据集提供的官方下载的版本。

# Copyright 2015 Google Inc. All Rights Reserved. 
# 
# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 
# 
#   http://www.apache.org/licenses/LICENSE-2.0 
# 
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
# ============================================================================== 
"""Functions for downloading and reading MNIST data.""" 
from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 
import gzip 
import os 
import tensorflow.python.platform 
import numpy 
from six.moves import urllib 
from six.moves import xrange # pylint: disable=redefined-builtin 
import tensorflow as tf 
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 
def maybe_download(filename, work_directory): 
 """Download the data from Yann's website, unless it's already here.""" 
 if not os.path.exists(work_directory): 
  os.mkdir(work_directory) 
 filepath = os.path.join(work_directory, filename) 
 if not os.path.exists(filepath): 
  filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) 
  statinfo = os.stat(filepath) 
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 
 return filepath 
def _read32(bytestream): 
 dt = numpy.dtype(numpy.uint32).newbyteorder('>') 
 return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 
def extract_images(filename): 
 """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2051: 
   raise ValueError( 
     'Invalid magic number %d in MNIST image file: %s' % 
     (magic, filename)) 
  num_images = _read32(bytestream) 
  rows = _read32(bytestream) 
  cols = _read32(bytestream) 
  buf = bytestream.read(rows * cols * num_images) 
  data = numpy.frombuffer(buf, dtype=numpy.uint8) 
  data = data.reshape(num_images, rows, cols, 1) 
  return data 
def dense_to_one_hot(labels_dense, num_classes=10): 
 """Convert class labels from scalars to one-hot vectors.""" 
 num_labels = labels_dense.shape[0] 
 index_offset = numpy.arange(num_labels) * num_classes 
 labels_one_hot = numpy.zeros((num_labels, num_classes)) 
 labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 
 return labels_one_hot 
def extract_labels(filename, one_hot=False): 
 """Extract the labels into a 1D uint8 numpy array [index].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2049: 
   raise ValueError( 
     'Invalid magic number %d in MNIST label file: %s' % 
     (magic, filename)) 
  num_items = _read32(bytestream) 
  buf = bytestream.read(num_items) 
  labels = numpy.frombuffer(buf, dtype=numpy.uint8) 
  if one_hot: 
   return dense_to_one_hot(labels) 
  return labels 
class DataSet(object): 
 def __init__(self, images, labels, fake_data=False, one_hot=False, 
        dtype=tf.float32): 
  """Construct a DataSet. 
  one_hot arg is used only if fake_data is true. `dtype` can be either 
  `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 
  `[0, 1]`. 
  """ 
  dtype = tf.as_dtype(dtype).base_dtype 
  if dtype not in (tf.uint8, tf.float32): 
   raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 
           dtype) 
  if fake_data: 
   self._num_examples = 10000 
   self.one_hot = one_hot 
  else: 
   assert images.shape[0] == labels.shape[0], ( 
     'images.shape: %s labels.shape: %s' % (images.shape, 
                         labels.shape)) 
   self._num_examples = images.shape[0] 
   # Convert shape from [num examples, rows, columns, depth] 
   # to [num examples, rows*columns] (assuming depth == 1) 
   assert images.shape[3] == 1 
   images = images.reshape(images.shape[0], 
               images.shape[1] * images.shape[2]) 
   if dtype == tf.float32: 
    # Convert from [0, 255] -> [0.0, 1.0]. 
    images = images.astype(numpy.float32) 
    images = numpy.multiply(images, 1.0 / 255.0) 
  self._images = images 
  self._labels = labels 
  self._epochs_completed = 0 
  self._index_in_epoch = 0 
 @property 
 def images(self): 
  return self._images 
 @property 
 def labels(self): 
  return self._labels 
 @property 
 def num_examples(self): 
  return self._num_examples 
 @property 
 def epochs_completed(self): 
  return self._epochs_completed 
 def next_batch(self, batch_size, fake_data=False): 
  """Return the next `batch_size` examples from this data set.""" 
  if fake_data: 
   fake_image = [1] * 784 
   if self.one_hot: 
    fake_label = [1] + [0] * 9 
   else: 
    fake_label = 0 
   return [fake_image for _ in xrange(batch_size)], [ 
     fake_label for _ in xrange(batch_size)] 
  start = self._index_in_epoch 
  self._index_in_epoch += batch_size 
  if self._index_in_epoch > self._num_examples: 
   # Finished epoch 
   self._epochs_completed += 1 
   # Shuffle the data 
   perm = numpy.arange(self._num_examples) 
   numpy.random.shuffle(perm) 
   self._images = self._images[perm] 
   self._labels = self._labels[perm] 
   # Start next epoch 
   start = 0 
   self._index_in_epoch = batch_size 
   assert batch_size <= self._num_examples 
  end = self._index_in_epoch 
  return self._images[start:end], self._labels[start:end] 
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): 
 class DataSets(object): 
  pass 
 data_sets = DataSets() 
 if fake_data: 
  def fake(): 
   return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 
  data_sets.train = fake() 
  data_sets.validation = fake() 
  data_sets.test = fake() 
  return data_sets 
 TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 
 TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 
 TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 
 TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 
 VALIDATION_SIZE = 5000 
 local_file = maybe_download(TRAIN_IMAGES, train_dir) 
 train_images = extract_images(local_file) 
 local_file = maybe_download(TRAIN_LABELS, train_dir) 
 train_labels = extract_labels(local_file, one_hot=one_hot) 
 local_file = maybe_download(TEST_IMAGES, train_dir) 
 test_images = extract_images(local_file) 
 local_file = maybe_download(TEST_LABELS, train_dir) 
 test_labels = extract_labels(local_file, one_hot=one_hot) 
 validation_images = train_images[:VALIDATION_SIZE] 
 validation_labels = train_labels[:VALIDATION_SIZE] 
 train_images = train_images[VALIDATION_SIZE:] 
 train_labels = train_labels[VALIDATION_SIZE:] 
 data_sets.train = DataSet(train_images, train_labels, dtype=dtype) 
 data_sets.validation = DataSet(validation_images, validation_labels, 
                 dtype=dtype) 
 data_sets.test = DataSet(test_images, test_labels, dtype=dtype) 
 return data_sets

然后进行代码的测试:

# import modules 
import sys 
import tensorflow as tf 
from PIL import Image, ImageFilter 
 
 
def predictint(imvalue): 
  """ 
  This function returns the predicted integer. 
  The imput is the pixel values from the imageprepare() function. 
  """ 
 
  # Define the model (same as when creating the model file) 
  x = tf.placeholder(tf.float32, [None, 784]) 
  W = tf.Variable(tf.zeros([784, 10])) 
  b = tf.Variable(tf.zeros([10])) 
 
  def weight_variable(shape): 
    initial = tf.truncated_normal(shape, stddev=0.1) 
    return tf.Variable(initial) 
 
  def bias_variable(shape): 
    initial = tf.constant(0.1, shape=shape) 
    return tf.Variable(initial) 
 
  def conv2d(x, W): 
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
  def max_pool_2x2(x): 
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 
 
  W_conv1 = weight_variable([5, 5, 1, 32]) 
  b_conv1 = bias_variable([32]) 
 
  x_image = tf.reshape(x, [-1, 28, 28, 1]) 
  h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
  h_pool1 = max_pool_2x2(h_conv1) 
 
  W_conv2 = weight_variable([5, 5, 32, 64]) 
  b_conv2 = bias_variable([64]) 
 
  h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
  h_pool2 = max_pool_2x2(h_conv2) 
 
  W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
  b_fc1 = bias_variable([1024]) 
 
  h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 
  h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
  keep_prob = tf.placeholder(tf.float32) 
  h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
  W_fc2 = weight_variable([1024, 10]) 
  b_fc2 = bias_variable([10]) 
 
  y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
  init_op = tf.initialize_all_variables() 
  saver = tf.train.Saver() 
 
  """ 
  Load the model_mnist.ckpt file 
  file is stored in the same directory as this python script is started 
  Use the model to predict the integer. Integer is returend as list. 
  Based on the documentatoin at 
  https://www.tensorflow.org/versions/master/how_tos/variables/index.html 
  """ 
  with tf.Session() as sess: 
    sess.run(init_op) 
    saver.restore(sess, "model_mnist.ckpt") 
    # print ("Model restored.") 
 
    prediction = tf.argmax(y_conv, 1) 
    return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess) 
 
 
def imageprepare(argv): 
  """ 
  This function returns the pixel values. 
  The imput is a png file location. 
  """ 
  im = Image.open(argv).convert('L') 
  width = float(im.size[0]) 
  height = float(im.size[1]) 
  newImage = Image.new('L', (28, 28), (255)) # creates white canvas of 28x28 pixels 
 
  if width > height: # check which dimension is bigger 
    # Width is bigger. Width becomes 20 pixels. 
    nheight = int(round((20.0 / width * height), 0)) # resize height according to ratio width 
    if (nheight == 0): # rare case but minimum is 1 pixel 
      nheigth = 1 
      # resize and sharpen 
    img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wtop = int(round(((28 - nheight) / 2), 0)) # caculate horizontal pozition 
    newImage.paste(img, (4, wtop)) # paste resized image on white canvas 
  else: 
    # Height is bigger. Heigth becomes 20 pixels. 
    nwidth = int(round((20.0 / height * width), 0)) # resize width according to ratio height 
    if (nwidth == 0): # rare case but minimum is 1 pixel 
      nwidth = 1 
      # resize and sharpen 
    img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wleft = int(round(((28 - nwidth) / 2), 0)) # caculate vertical pozition 
    newImage.paste(img, (wleft, 4)) # paste resized image on white canvas 
 
  # newImage.save("sample.png") 
 
  tv = list(newImage.getdata()) # get pixel values 
 
  # normalize pixels to 0 and 1. 0 is pure white, 1 is pure black. 
  tva = [(255 - x) * 1.0 / 255.0 for x in tv] 
  return tva 
  # print(tva) 
 
 
def main(argv): 
  """ 
  Main function. 
  """ 
  imvalue = imageprepare(argv) 
  predint = predictint(imvalue) 
  print (predint[0]) # first value in list 
 
 
if __name__ == "__main__": 
  main('2.png')

其中我用于测试的代码如下:

tensorflow识别自己手写数字

可以将图片另存到路径下面,然后进行测试。

(1)载入我的手写数字的图像。
(2)将图像转换为黑白(模式“L”)
(3)确定原始图像的尺寸是最大的
(4)调整图像的大小,使得最大尺寸(醚的高度及宽度)为20像素,并且以相同的比例最小化尺寸刻度。
(5)锐化图像。这会极大地强化结果。
(6)把图像粘贴在28×28像素的白色画布上。在最大的尺寸上从顶部或侧面居中图像4个像素。最大尺寸始终是20个像素和4 + 20 + 4 = 28,最小尺寸被定位在28和缩放的图像的新的大小之间差的一半。
(7)获取新的图像(画布+居中的图像)的像素值。
(8)归一化像素值到0和1之间的一个值(这也在TensorFlow MNIST教程中完成)。其中0是白色的,1是纯黑色。从步骤7得到的像素值是与之相反的,其中255是白色的,0黑色,所以数值必须反转。下述公式包括反转和规格化(255-X)* 1.0 / 255.0

Python 相关文章推荐
python切换hosts文件代码示例
Dec 31 Python
python根据距离和时长计算配速示例
Feb 16 Python
Python中的两个内置模块介绍
Apr 05 Python
使用httplib模块来制作Python下HTTP客户端的方法
Jun 19 Python
使用Python生成随机密码的示例分享
Feb 18 Python
python 基础教程之Map使用方法
Jan 17 Python
python3实现磁盘空间监控
Jun 21 Python
python处理“
Jun 10 Python
PyQt5实现从主窗口打开子窗口的方法
Jun 19 Python
Python实现某论坛自动签到功能
Aug 20 Python
使用Python3 poplib模块删除服务器多天前的邮件实现代码
Apr 24 Python
django中ImageField的使用详解
Dec 21 Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
python批量设置多个Excel文件页眉页脚的脚本
Mar 14 #Python
You might like
神族 Protoss 剧情介绍
2020/03/14 星际争霸
PHP 存取 MySQL 数据库的一个例子
2006/10/09 PHP
php 注释规范
2012/03/29 PHP
php异常处理使用示例
2014/02/25 PHP
调试PHP程序的多种方法介绍
2014/11/06 PHP
php+redis消息队列实现抢购功能
2018/02/08 PHP
Laravel5框架自定义错误页面配置操作示例
2019/04/17 PHP
jQuery实现随意改变div任意属性的名称和值(部分原生js实现)
2013/05/28 Javascript
JSuggest自动匹配下拉框使用方法(示例代码)
2013/12/27 Javascript
JavaScript修改浏览器tab标题小技巧
2015/01/06 Javascript
ECMAScript 5中的属性描述符详解
2015/03/02 Javascript
js脚本分页代码分享(7种样式)
2015/08/19 Javascript
JavaScript实现算术平方根算法-代码超简单
2015/09/11 Javascript
全面解析Bootstrap表单使用方法(表单样式)
2015/11/24 Javascript
剖析Node.js异步编程中的回调与代码设计模式
2016/02/16 Javascript
一个用jquery写的判断div滚动条到底部的方法【推荐】
2016/04/29 Javascript
原生JS实现图片左右轮播
2016/12/30 Javascript
详解nodeJS之路径PATH模块
2017/05/31 NodeJs
原生JS实现 MUI导航栏透明渐变效果
2017/11/07 Javascript
Vue中使用better-scroll实现轮播图组件
2020/03/07 Javascript
如何区分vue中的v-show 与 v-if
2020/09/08 Javascript
[01:53]2016完美“圣”典风云人物:Maybe专访
2016/12/05 DOTA
Python中的两个内置模块介绍
2015/04/05 Python
python切片(获取一个子列表(数组))详解
2019/08/09 Python
css3使用animation属性实现炫酷效果(推荐)
2020/02/04 HTML / CSS
大学生的创业计划书就该这么写
2014/01/30 职场文书
优秀员工推荐信
2014/05/10 职场文书
初中教师业务学习材料
2014/05/12 职场文书
2014年学校工会工作总结
2014/12/06 职场文书
超市督导岗位职责
2015/04/10 职场文书
2015大学迎新晚会策划书
2015/07/16 职场文书
检讨书怎么写?
2019/06/21 职场文书
个人工作总结(管理人员)范文
2019/08/13 职场文书
详解RedisTemplate下Redis分布式锁引发的系列问题
2021/04/27 Redis
redis 限制内存使用大小的实现
2021/05/08 Redis
Python实现为PDF去除水印的示例代码
2022/04/03 Python