tensorflow识别自己手写数字


Posted in Python onMarch 14, 2018

tensorflow作为google开源的项目,现在赶超了caffe,好像成为最受欢迎的深度学习框架。确实在编写的时候更能感受到代码的真实存在,这点和caffe不同,caffe通过编写配置文件进行网络的生成。环境tensorflow是0.10的版本,注意其他版本有的语句会有错误,这是tensorflow版本之间的兼容问题。

还需要安装PIL:pip install Pillow

图片的格式: 

? 图像标准化,可安装在20×20像素的框内,同时保留其长宽比。
? 图片都集中在一个28×28的图像中。
? 像素以列为主进行排序。像素值0到255,0表示背景(白色),255表示前景(黑色)。

创建一个.png的文件,背景是白色的,手写的字体是黑色的,

下面是数据测试的代码,一个两层的卷积神经网,然后用save进行模型的保存。

# coding: UTF-8 
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import input_data 
''''' 
得到数据 
''' 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
 
training = mnist.train.images 
trainlable = mnist.train.labels 
testing = mnist.test.images 
testlabel = mnist.test.labels 
 
print ("MNIST loaded") 
# 获取交互式的方式 
sess = tf.InteractiveSession() 
# 初始化变量 
x = tf.placeholder("float", shape=[None, 784]) 
y_ = tf.placeholder("float", shape=[None, 10]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
''''' 
生成权重函数,其中shape是数据的形状 
''' 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
''''' 
生成偏执项 其中shape是数据形状 
''' 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 
             strides=[1, 2, 2, 1], padding='SAME') 
 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
b_fc1 = bias_variable([1024]) 
 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
keep_prob = tf.placeholder("float") 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
 
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
 
# 保存网络训练的参数 
saver = tf.train.Saver() 
sess.run(tf.initialize_all_variables()) 
for i in range(8000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={ 
    x:batch[0], y_: batch[1], keep_prob: 1.0}) 
  print "step %d, training accuracy %g"%(i, train_accuracy) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 
 
save_path = saver.save(sess, "model_mnist.ckpt") 
print("Model saved in life:", save_path) 
 
print "test accuracy %g"%accuracy.eval(feed_dict={ 
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

其中input_data.py如下代码,是进行mnist数据集的下载的:代码是由mnist数据集提供的官方下载的版本。

# Copyright 2015 Google Inc. All Rights Reserved. 
# 
# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 
# 
#   http://www.apache.org/licenses/LICENSE-2.0 
# 
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
# ============================================================================== 
"""Functions for downloading and reading MNIST data.""" 
from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 
import gzip 
import os 
import tensorflow.python.platform 
import numpy 
from six.moves import urllib 
from six.moves import xrange # pylint: disable=redefined-builtin 
import tensorflow as tf 
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 
def maybe_download(filename, work_directory): 
 """Download the data from Yann's website, unless it's already here.""" 
 if not os.path.exists(work_directory): 
  os.mkdir(work_directory) 
 filepath = os.path.join(work_directory, filename) 
 if not os.path.exists(filepath): 
  filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) 
  statinfo = os.stat(filepath) 
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 
 return filepath 
def _read32(bytestream): 
 dt = numpy.dtype(numpy.uint32).newbyteorder('>') 
 return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 
def extract_images(filename): 
 """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2051: 
   raise ValueError( 
     'Invalid magic number %d in MNIST image file: %s' % 
     (magic, filename)) 
  num_images = _read32(bytestream) 
  rows = _read32(bytestream) 
  cols = _read32(bytestream) 
  buf = bytestream.read(rows * cols * num_images) 
  data = numpy.frombuffer(buf, dtype=numpy.uint8) 
  data = data.reshape(num_images, rows, cols, 1) 
  return data 
def dense_to_one_hot(labels_dense, num_classes=10): 
 """Convert class labels from scalars to one-hot vectors.""" 
 num_labels = labels_dense.shape[0] 
 index_offset = numpy.arange(num_labels) * num_classes 
 labels_one_hot = numpy.zeros((num_labels, num_classes)) 
 labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 
 return labels_one_hot 
def extract_labels(filename, one_hot=False): 
 """Extract the labels into a 1D uint8 numpy array [index].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2049: 
   raise ValueError( 
     'Invalid magic number %d in MNIST label file: %s' % 
     (magic, filename)) 
  num_items = _read32(bytestream) 
  buf = bytestream.read(num_items) 
  labels = numpy.frombuffer(buf, dtype=numpy.uint8) 
  if one_hot: 
   return dense_to_one_hot(labels) 
  return labels 
class DataSet(object): 
 def __init__(self, images, labels, fake_data=False, one_hot=False, 
        dtype=tf.float32): 
  """Construct a DataSet. 
  one_hot arg is used only if fake_data is true. `dtype` can be either 
  `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 
  `[0, 1]`. 
  """ 
  dtype = tf.as_dtype(dtype).base_dtype 
  if dtype not in (tf.uint8, tf.float32): 
   raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 
           dtype) 
  if fake_data: 
   self._num_examples = 10000 
   self.one_hot = one_hot 
  else: 
   assert images.shape[0] == labels.shape[0], ( 
     'images.shape: %s labels.shape: %s' % (images.shape, 
                         labels.shape)) 
   self._num_examples = images.shape[0] 
   # Convert shape from [num examples, rows, columns, depth] 
   # to [num examples, rows*columns] (assuming depth == 1) 
   assert images.shape[3] == 1 
   images = images.reshape(images.shape[0], 
               images.shape[1] * images.shape[2]) 
   if dtype == tf.float32: 
    # Convert from [0, 255] -> [0.0, 1.0]. 
    images = images.astype(numpy.float32) 
    images = numpy.multiply(images, 1.0 / 255.0) 
  self._images = images 
  self._labels = labels 
  self._epochs_completed = 0 
  self._index_in_epoch = 0 
 @property 
 def images(self): 
  return self._images 
 @property 
 def labels(self): 
  return self._labels 
 @property 
 def num_examples(self): 
  return self._num_examples 
 @property 
 def epochs_completed(self): 
  return self._epochs_completed 
 def next_batch(self, batch_size, fake_data=False): 
  """Return the next `batch_size` examples from this data set.""" 
  if fake_data: 
   fake_image = [1] * 784 
   if self.one_hot: 
    fake_label = [1] + [0] * 9 
   else: 
    fake_label = 0 
   return [fake_image for _ in xrange(batch_size)], [ 
     fake_label for _ in xrange(batch_size)] 
  start = self._index_in_epoch 
  self._index_in_epoch += batch_size 
  if self._index_in_epoch > self._num_examples: 
   # Finished epoch 
   self._epochs_completed += 1 
   # Shuffle the data 
   perm = numpy.arange(self._num_examples) 
   numpy.random.shuffle(perm) 
   self._images = self._images[perm] 
   self._labels = self._labels[perm] 
   # Start next epoch 
   start = 0 
   self._index_in_epoch = batch_size 
   assert batch_size <= self._num_examples 
  end = self._index_in_epoch 
  return self._images[start:end], self._labels[start:end] 
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): 
 class DataSets(object): 
  pass 
 data_sets = DataSets() 
 if fake_data: 
  def fake(): 
   return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 
  data_sets.train = fake() 
  data_sets.validation = fake() 
  data_sets.test = fake() 
  return data_sets 
 TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 
 TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 
 TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 
 TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 
 VALIDATION_SIZE = 5000 
 local_file = maybe_download(TRAIN_IMAGES, train_dir) 
 train_images = extract_images(local_file) 
 local_file = maybe_download(TRAIN_LABELS, train_dir) 
 train_labels = extract_labels(local_file, one_hot=one_hot) 
 local_file = maybe_download(TEST_IMAGES, train_dir) 
 test_images = extract_images(local_file) 
 local_file = maybe_download(TEST_LABELS, train_dir) 
 test_labels = extract_labels(local_file, one_hot=one_hot) 
 validation_images = train_images[:VALIDATION_SIZE] 
 validation_labels = train_labels[:VALIDATION_SIZE] 
 train_images = train_images[VALIDATION_SIZE:] 
 train_labels = train_labels[VALIDATION_SIZE:] 
 data_sets.train = DataSet(train_images, train_labels, dtype=dtype) 
 data_sets.validation = DataSet(validation_images, validation_labels, 
                 dtype=dtype) 
 data_sets.test = DataSet(test_images, test_labels, dtype=dtype) 
 return data_sets

然后进行代码的测试:

# import modules 
import sys 
import tensorflow as tf 
from PIL import Image, ImageFilter 
 
 
def predictint(imvalue): 
  """ 
  This function returns the predicted integer. 
  The imput is the pixel values from the imageprepare() function. 
  """ 
 
  # Define the model (same as when creating the model file) 
  x = tf.placeholder(tf.float32, [None, 784]) 
  W = tf.Variable(tf.zeros([784, 10])) 
  b = tf.Variable(tf.zeros([10])) 
 
  def weight_variable(shape): 
    initial = tf.truncated_normal(shape, stddev=0.1) 
    return tf.Variable(initial) 
 
  def bias_variable(shape): 
    initial = tf.constant(0.1, shape=shape) 
    return tf.Variable(initial) 
 
  def conv2d(x, W): 
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
  def max_pool_2x2(x): 
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 
 
  W_conv1 = weight_variable([5, 5, 1, 32]) 
  b_conv1 = bias_variable([32]) 
 
  x_image = tf.reshape(x, [-1, 28, 28, 1]) 
  h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
  h_pool1 = max_pool_2x2(h_conv1) 
 
  W_conv2 = weight_variable([5, 5, 32, 64]) 
  b_conv2 = bias_variable([64]) 
 
  h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
  h_pool2 = max_pool_2x2(h_conv2) 
 
  W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
  b_fc1 = bias_variable([1024]) 
 
  h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 
  h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
  keep_prob = tf.placeholder(tf.float32) 
  h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
  W_fc2 = weight_variable([1024, 10]) 
  b_fc2 = bias_variable([10]) 
 
  y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
  init_op = tf.initialize_all_variables() 
  saver = tf.train.Saver() 
 
  """ 
  Load the model_mnist.ckpt file 
  file is stored in the same directory as this python script is started 
  Use the model to predict the integer. Integer is returend as list. 
  Based on the documentatoin at 
  https://www.tensorflow.org/versions/master/how_tos/variables/index.html 
  """ 
  with tf.Session() as sess: 
    sess.run(init_op) 
    saver.restore(sess, "model_mnist.ckpt") 
    # print ("Model restored.") 
 
    prediction = tf.argmax(y_conv, 1) 
    return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess) 
 
 
def imageprepare(argv): 
  """ 
  This function returns the pixel values. 
  The imput is a png file location. 
  """ 
  im = Image.open(argv).convert('L') 
  width = float(im.size[0]) 
  height = float(im.size[1]) 
  newImage = Image.new('L', (28, 28), (255)) # creates white canvas of 28x28 pixels 
 
  if width > height: # check which dimension is bigger 
    # Width is bigger. Width becomes 20 pixels. 
    nheight = int(round((20.0 / width * height), 0)) # resize height according to ratio width 
    if (nheight == 0): # rare case but minimum is 1 pixel 
      nheigth = 1 
      # resize and sharpen 
    img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wtop = int(round(((28 - nheight) / 2), 0)) # caculate horizontal pozition 
    newImage.paste(img, (4, wtop)) # paste resized image on white canvas 
  else: 
    # Height is bigger. Heigth becomes 20 pixels. 
    nwidth = int(round((20.0 / height * width), 0)) # resize width according to ratio height 
    if (nwidth == 0): # rare case but minimum is 1 pixel 
      nwidth = 1 
      # resize and sharpen 
    img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wleft = int(round(((28 - nwidth) / 2), 0)) # caculate vertical pozition 
    newImage.paste(img, (wleft, 4)) # paste resized image on white canvas 
 
  # newImage.save("sample.png") 
 
  tv = list(newImage.getdata()) # get pixel values 
 
  # normalize pixels to 0 and 1. 0 is pure white, 1 is pure black. 
  tva = [(255 - x) * 1.0 / 255.0 for x in tv] 
  return tva 
  # print(tva) 
 
 
def main(argv): 
  """ 
  Main function. 
  """ 
  imvalue = imageprepare(argv) 
  predint = predictint(imvalue) 
  print (predint[0]) # first value in list 
 
 
if __name__ == "__main__": 
  main('2.png')

其中我用于测试的代码如下:

tensorflow识别自己手写数字

可以将图片另存到路径下面,然后进行测试。

(1)载入我的手写数字的图像。
(2)将图像转换为黑白(模式“L”)
(3)确定原始图像的尺寸是最大的
(4)调整图像的大小,使得最大尺寸(醚的高度及宽度)为20像素,并且以相同的比例最小化尺寸刻度。
(5)锐化图像。这会极大地强化结果。
(6)把图像粘贴在28×28像素的白色画布上。在最大的尺寸上从顶部或侧面居中图像4个像素。最大尺寸始终是20个像素和4 + 20 + 4 = 28,最小尺寸被定位在28和缩放的图像的新的大小之间差的一半。
(7)获取新的图像(画布+居中的图像)的像素值。
(8)归一化像素值到0和1之间的一个值(这也在TensorFlow MNIST教程中完成)。其中0是白色的,1是纯黑色。从步骤7得到的像素值是与之相反的,其中255是白色的,0黑色,所以数值必须反转。下述公式包括反转和规格化(255-X)* 1.0 / 255.0

Python 相关文章推荐
使用python实现baidu hi自动登录的代码
Feb 10 Python
python使用urllib2模块获取gravatar头像实例
Dec 18 Python
Python编程之微信推送模板消息功能示例
Aug 21 Python
浅谈pycharm的xmx和xms设置方法
Dec 03 Python
python3实现网络爬虫之BeautifulSoup使用详解
Dec 19 Python
python判断单向链表是否包括环,若包含则计算环入口的节点实例分析
Oct 23 Python
Python OrderedDict的使用案例解析
Oct 25 Python
python简单的三元一次方程求解实例
Apr 02 Python
python 追踪except信息方式
Apr 25 Python
Python+Kepler.gl轻松制作酷炫路径动画的实现示例
Jun 02 Python
用opencv给图片换背景色的示例代码
Jul 08 Python
最简单的matplotlib安装教程(小白)
Jul 28 Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
python批量设置多个Excel文件页眉页脚的脚本
Mar 14 #Python
You might like
资料注册后发信小技巧
2006/10/09 PHP
别人整理的服务器变量:$_SERVER
2006/10/20 PHP
PHP中文汉字验证码
2007/04/08 PHP
Yii2实现log输出到file及database的方法
2016/11/12 PHP
php 文件下载 出现下载文件内容乱码损坏的解决方法(推荐)
2016/11/16 PHP
php查找字符串中第一个非0的位置截取
2017/02/27 PHP
JS类中定义原型方法的两种实现的区别
2007/03/08 Javascript
javascript[js]获取url参数的代码
2007/10/17 Javascript
有关javascript的性能优化 (repaint和reflow)
2013/04/12 Javascript
超棒的响应式布局jQuery插件Freetile.js
2014/11/17 Javascript
jQuery 3.0 的变化及使用方法
2016/02/01 Javascript
js中获取时间new Date()的全面介绍
2016/06/20 Javascript
jQuery+CSS3实现仿花瓣网固定顶部位置带悬浮效果的导航菜单
2016/09/21 Javascript
微信小程序 数据封装,参数传值等经验分享
2017/01/09 Javascript
js+css3实现旋转效果
2017/01/20 Javascript
JavaScript实现简单生成随机颜色的方法
2017/09/21 Javascript
JavaScript Tab菜单实现过程解析
2020/05/13 Javascript
vue-以文件流-blob-的形式-下载-导出文件操作
2020/08/07 Javascript
使用Python编写一个在Linux下实现截图分享的脚本的教程
2015/04/24 Python
解读Python编程中的命名空间与作用域
2015/10/16 Python
正确理解python中的关键字“with”与上下文管理器
2017/04/21 Python
Django中的文件的上传的几种方式
2018/07/23 Python
解决Pandas的DataFrame输出截断和省略的问题
2019/02/08 Python
python中matplotlib条件背景颜色的实现
2019/09/02 Python
Python JSON编解码方式原理详解
2020/01/20 Python
国家地理在线商店:Shop National Geographic
2018/06/30 全球购物
eDreams意大利:南欧领先的在线旅行社
2018/11/23 全球购物
C/C++程序员常见面试题一
2012/12/08 面试题
SQL Server数据库笔试题和答案
2016/02/04 面试题
爱国卫生月实施方案
2014/02/21 职场文书
2015年教师新年寄语
2014/12/08 职场文书
2015年教师节主持词
2015/07/03 职场文书
2016教师六五普法学习心得体会
2016/01/21 职场文书
MySQL 隔离数据列和前缀索引的使用总结
2021/05/14 MySQL
PyTorch梯度裁剪避免训练loss nan的操作
2021/05/24 Python
Python内置数据结构列表与元组示例详解
2021/08/04 Python