详解如何从TensorFlow的mnist数据集导出手写体数字图片


Posted in Python onAugust 05, 2019

在TensorFlow的官方入门课程中,多次用到mnist数据集。

mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二进制文件。

如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片。了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解。

下面先给出通过TensorFlow api接口导出mnist手写体数字图片的python代码,再对代码进行分析。代码在win7下测试通过,linux环境也可以参考本处代码。

(非常良心的注释和打印有木有)

#!/usr/bin/python3.5
# -*- coding: utf-8 -*-
 
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
from PIL import Image
 
# 声明图片宽高
rows = 28
cols = 28
 
# 要提取的图片数量
images_to_extract = 8000
 
# 当前路径下的保存目录
save_dir = "./mnist_digits_images"
 
# 读入mnist数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
 
# 创建会话
sess = tf.Session()
 
# 获取图片总数
shape = sess.run(tf.shape(mnist.train.images))
images_count = shape[0]
pixels_per_image = shape[1]
 
# 获取标签总数
shape = sess.run(tf.shape(mnist.train.labels))
labels_count = shape[0]
 
# mnist.train.labels是一个二维张量,为便于后续生成数字图片目录名,有必要一维化(后来发现只要把数据集的one_hot属性设为False,mnist.train.labels本身就是一维)
#labels = sess.run(tf.argmax(mnist.train.labels, 1))
labels = mnist.train.labels
 
# 检查数据集是否符合预期格式
if (images_count == labels_count) and (shape.size == 1):
  print ("数据集总共包含 %s 张图片,和 %s 个标签" % (images_count, labels_count))
  print ("每张图片包含 %s 个像素" % (pixels_per_image))
  print ("数据类型:%s" % (mnist.train.images.dtype))
 
  # mnist图像数据的数值范围是[0,1],需要扩展到[0,255],以便于人眼观看
  if mnist.train.images.dtype == "float32":
    print ("准备将数据类型从[0,1]转为binary[0,255]...")
    for i in range(0,images_to_extract):
      for n in range(pixels_per_image):
        if mnist.train.images[i][n] != 0:
          mnist.train.images[i][n] = 255
      # 由于数据集图片数量庞大,转换可能要花不少时间,有必要打印转换进度
      if ((i+1)%50) == 0:
        print ("图像浮点数值扩展进度:已转换 %s 张,共需转换 %s 张" % (i+1, images_to_extract))
 
  # 创建数字图片的保存目录
  for i in range(10):
    dir = "%s/%s/" % (save_dir,i)
    if not os.path.exists(dir):
      print ("目录 ""%s"" 不存在!自动创建该目录..." % dir)
      os.makedirs(dir)
 
  # 通过python图片处理库,生成图片
  indices = [0 for x in range(0, 10)]
  for i in range(0,images_to_extract):
    img = Image.new("L",(cols,rows))
    for m in range(rows):
      for n in range(cols):
        img.putpixel((n,m), int(mnist.train.images[i][n+m*cols]))
    # 根据图片所代表的数字label生成对应的保存路径
    digit = labels[i]
    path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])
    indices[digit] += 1
    img.save(path)
    # 由于数据集图片数量庞大,保存过程可能要花不少时间,有必要打印保存进度
    if ((i+1)%50) == 0:
      print ("图片保存进度:已保存 %s 张,共需保存 %s 张" % (i+1, images_to_extract))
  
else:
  print ("图片数量和标签数量不一致!")

上述代码的实现思路如下:

1.读入mnist手写体数据;

2.把数据的值从[0,1]浮点范围转化为黑白格式(背景为0-黑色,前景为255-白色);

3.根据mnist.train.labels的内容,生成数字索引,也就是建立每一张图片和其所代表数字的关联,由此创建对应的保存目录;

4.循环遍历mnist.train.images,把每张图片的像素数据赋值给python图片处理库PIL的Image类实例,再调用Image类的save方法把图片保存在第3步骤中创建的对应目录。

在运行上述代码之前,你需要确保本地已经安装python的图片处理库PIL,pip安装命令如下:

pip3 install Pillow

或 pip install Pillow,取决于你的pip版本。

上述python代码运行后,在当前目录下会生成mnist_digits_images目录,在该目录下,可以看到如下内容:

详解如何从TensorFlow的mnist数据集导出手写体数字图片

可以看到,我们成功地生成了黑底白字的数字图片。

如果仔细观察这些图片,会看到一些肉眼也难以分辨的数字,譬如:

详解如何从TensorFlow的mnist数据集导出手写体数字图片

详解如何从TensorFlow的mnist数据集导出手写体数字图片

上面这几个数字是2。想不到吧?

下面这两个是5(看起来更像6):

详解如何从TensorFlow的mnist数据集导出手写体数字图片详解如何从TensorFlow的mnist数据集导出手写体数字图片

这个是7:(7长这样?有句MMP不知当讲不当讲)

详解如何从TensorFlow的mnist数据集导出手写体数字图片

猜猜下面这个是什么:

详解如何从TensorFlow的mnist数据集导出手写体数字图片

这是大写的L?不是。

有点像1,是1吗?也不是。

倒立拉粑的7?sorry,又猜错了。

实话告诉您,它是2!一开始我也是不相信的,知道真相的那一刻我下巴差点掉下来!

这些手写图片,一般人用肉眼观察,识别率能达到98%就不错了,但是通过TensorFlow搭建的卷积神经网络识别率可以达到99%,非常地神奇!

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python用10行代码实现对黄色图片的检测功能
Aug 10 Python
安装ElasticSearch搜索工具并配置Python驱动的方法
Dec 22 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
基于python时间处理方法(详解)
Aug 14 Python
Python实现基于多线程、多用户的FTP服务器与客户端功能完整实例
Aug 18 Python
R语言 vs Python对比:数据分析哪家强?
Nov 17 Python
在Qt5和PyQt5中设置支持高分辨率屏幕自适应的方法
Jun 18 Python
python多线程与多进程及其区别详解
Aug 08 Python
python selenium爬取斗鱼所有直播房间信息过程详解
Aug 09 Python
python 实现二维字典的键值合并等函数
Dec 06 Python
python实现单张图像拼接与批量图片拼接
Mar 23 Python
删除pycharm鼠标右键快捷键打开项目的操作
Jan 16 Python
Python获取时间范围内日期列表和周列表的函数
Aug 05 #Python
Django ORM 查询管理器源码解析
Aug 05 #Python
python实现车牌识别的示例代码
Aug 05 #Python
使用python实现滑动验证码功能
Aug 05 #Python
Django 源码WSGI剖析过程详解
Aug 05 #Python
Python使用itchat 功能分析微信好友性别和位置
Aug 05 #Python
Python队列RabbitMQ 使用方法实例记录
Aug 05 #Python
You might like
当海贼王变成JOJO风
2020/03/02 日漫
php获取目录所有文件并将结果保存到数组(实例)
2013/10/25 PHP
php数组查找函数in_array()、array_search()、array_key_exists()使用实例
2014/04/29 PHP
ThinkPHP提示错误Fatal error: Allowed memory size的解决方法
2015/02/12 PHP
php 微信公众平台开发模式实现多客服的实例代码
2016/11/07 PHP
php基于dom实现读取图书xml格式数据的方法
2017/02/03 PHP
Iframe thickbox2.0使用的方法
2009/03/05 Javascript
关于jQuery UI 使用心得及技巧
2012/10/10 Javascript
javascript中的onkeyup和onkeydown区别介绍
2013/04/28 Javascript
JS实现匀速运动的代码实例
2013/11/29 Javascript
浅析JavaScript原型继承的陷阱
2013/12/03 Javascript
jQuery的live()方法对hover事件的处理示例
2014/02/27 Javascript
javascript event在FF和IE的兼容传参心得(绝对好用)
2014/07/10 Javascript
JavaScript中property和attribute的区别详细介绍
2015/03/03 Javascript
如何实现JavaScript动态加载CSS和JS文件
2020/12/28 Javascript
BootStrap 表单控件之单选按钮水平排列
2017/05/23 Javascript
vue利用better-scroll实现轮播图与页面滚动详解
2017/10/20 Javascript
基于vue cli重构多页面脚手架过程详解
2018/01/23 Javascript
JavaScript日期工具类DateUtils定义与用法示例
2018/09/03 Javascript
js前端面试之同步与异步问题详解
2019/04/03 Javascript
Vue路由守卫之路由独享守卫
2019/09/25 Javascript
js实现旋转的星空效果
2019/11/01 Javascript
vue 实现把路由单独分离出来
2020/08/13 Javascript
python获取网页状态码示例
2014/03/30 Python
Python提示[Errno 32]Broken pipe导致线程crash错误解决方法
2014/11/19 Python
python实现实时视频流播放代码实例
2020/01/11 Python
浅谈Keras中shuffle和validation_split的顺序
2020/06/19 Python
python将下载到本地m3u8视频合成MP4的代码详解
2020/11/24 Python
什么是CSS3 HSLA色彩模式?HSLA模拟渐变色条
2016/04/26 HTML / CSS
CSS3 Flex 弹性布局实例代码详解
2018/11/01 HTML / CSS
Roxy美国官网:澳大利亚冲浪、滑雪健身品牌
2016/07/30 全球购物
美国知名的百货清仓店:Neiman Marcus Last Call
2016/08/03 全球购物
小班重阳节活动方案
2014/02/08 职场文书
公司合并协议书范本
2014/09/30 职场文书
个人房屋买卖协议书(范本)
2014/10/04 职场文书
钳工实训报告总结
2014/11/04 职场文书