详解如何从TensorFlow的mnist数据集导出手写体数字图片


Posted in Python onAugust 05, 2019

在TensorFlow的官方入门课程中,多次用到mnist数据集。

mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二进制文件。

如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片。了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解。

下面先给出通过TensorFlow api接口导出mnist手写体数字图片的python代码,再对代码进行分析。代码在win7下测试通过,linux环境也可以参考本处代码。

(非常良心的注释和打印有木有)

#!/usr/bin/python3.5
# -*- coding: utf-8 -*-
 
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
from PIL import Image
 
# 声明图片宽高
rows = 28
cols = 28
 
# 要提取的图片数量
images_to_extract = 8000
 
# 当前路径下的保存目录
save_dir = "./mnist_digits_images"
 
# 读入mnist数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
 
# 创建会话
sess = tf.Session()
 
# 获取图片总数
shape = sess.run(tf.shape(mnist.train.images))
images_count = shape[0]
pixels_per_image = shape[1]
 
# 获取标签总数
shape = sess.run(tf.shape(mnist.train.labels))
labels_count = shape[0]
 
# mnist.train.labels是一个二维张量,为便于后续生成数字图片目录名,有必要一维化(后来发现只要把数据集的one_hot属性设为False,mnist.train.labels本身就是一维)
#labels = sess.run(tf.argmax(mnist.train.labels, 1))
labels = mnist.train.labels
 
# 检查数据集是否符合预期格式
if (images_count == labels_count) and (shape.size == 1):
  print ("数据集总共包含 %s 张图片,和 %s 个标签" % (images_count, labels_count))
  print ("每张图片包含 %s 个像素" % (pixels_per_image))
  print ("数据类型:%s" % (mnist.train.images.dtype))
 
  # mnist图像数据的数值范围是[0,1],需要扩展到[0,255],以便于人眼观看
  if mnist.train.images.dtype == "float32":
    print ("准备将数据类型从[0,1]转为binary[0,255]...")
    for i in range(0,images_to_extract):
      for n in range(pixels_per_image):
        if mnist.train.images[i][n] != 0:
          mnist.train.images[i][n] = 255
      # 由于数据集图片数量庞大,转换可能要花不少时间,有必要打印转换进度
      if ((i+1)%50) == 0:
        print ("图像浮点数值扩展进度:已转换 %s 张,共需转换 %s 张" % (i+1, images_to_extract))
 
  # 创建数字图片的保存目录
  for i in range(10):
    dir = "%s/%s/" % (save_dir,i)
    if not os.path.exists(dir):
      print ("目录 ""%s"" 不存在!自动创建该目录..." % dir)
      os.makedirs(dir)
 
  # 通过python图片处理库,生成图片
  indices = [0 for x in range(0, 10)]
  for i in range(0,images_to_extract):
    img = Image.new("L",(cols,rows))
    for m in range(rows):
      for n in range(cols):
        img.putpixel((n,m), int(mnist.train.images[i][n+m*cols]))
    # 根据图片所代表的数字label生成对应的保存路径
    digit = labels[i]
    path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])
    indices[digit] += 1
    img.save(path)
    # 由于数据集图片数量庞大,保存过程可能要花不少时间,有必要打印保存进度
    if ((i+1)%50) == 0:
      print ("图片保存进度:已保存 %s 张,共需保存 %s 张" % (i+1, images_to_extract))
  
else:
  print ("图片数量和标签数量不一致!")

上述代码的实现思路如下:

1.读入mnist手写体数据;

2.把数据的值从[0,1]浮点范围转化为黑白格式(背景为0-黑色,前景为255-白色);

3.根据mnist.train.labels的内容,生成数字索引,也就是建立每一张图片和其所代表数字的关联,由此创建对应的保存目录;

4.循环遍历mnist.train.images,把每张图片的像素数据赋值给python图片处理库PIL的Image类实例,再调用Image类的save方法把图片保存在第3步骤中创建的对应目录。

在运行上述代码之前,你需要确保本地已经安装python的图片处理库PIL,pip安装命令如下:

pip3 install Pillow

或 pip install Pillow,取决于你的pip版本。

上述python代码运行后,在当前目录下会生成mnist_digits_images目录,在该目录下,可以看到如下内容:

详解如何从TensorFlow的mnist数据集导出手写体数字图片

可以看到,我们成功地生成了黑底白字的数字图片。

如果仔细观察这些图片,会看到一些肉眼也难以分辨的数字,譬如:

详解如何从TensorFlow的mnist数据集导出手写体数字图片

详解如何从TensorFlow的mnist数据集导出手写体数字图片

上面这几个数字是2。想不到吧?

下面这两个是5(看起来更像6):

详解如何从TensorFlow的mnist数据集导出手写体数字图片详解如何从TensorFlow的mnist数据集导出手写体数字图片

这个是7:(7长这样?有句MMP不知当讲不当讲)

详解如何从TensorFlow的mnist数据集导出手写体数字图片

猜猜下面这个是什么:

详解如何从TensorFlow的mnist数据集导出手写体数字图片

这是大写的L?不是。

有点像1,是1吗?也不是。

倒立拉粑的7?sorry,又猜错了。

实话告诉您,它是2!一开始我也是不相信的,知道真相的那一刻我下巴差点掉下来!

这些手写图片,一般人用肉眼观察,识别率能达到98%就不错了,但是通过TensorFlow搭建的卷积神经网络识别率可以达到99%,非常地神奇!

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用IronPython把Python脚本集成到.NET程序中的教程
Mar 31 Python
Python教程之全局变量用法
Jun 27 Python
Python对文件和目录进行操作的方法(file对象/os/os.path/shutil 模块)
May 08 Python
python读取TXT每行,并存到LIST中的方法
Oct 26 Python
python ChainMap的使用和说明详解
Jun 11 Python
python爬虫爬取监控教务系统的思路详解
Jan 08 Python
Python字典深浅拷贝与循环方式方法详解
Feb 09 Python
python-sys.stdout作为默认函数参数的实现
Feb 21 Python
浅析Python requests 模块
Oct 09 Python
详解解决jupyter不能使用pytorch的问题
Feb 18 Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 Python
Python数据清洗工具之Numpy的基本操作
Apr 22 Python
Python获取时间范围内日期列表和周列表的函数
Aug 05 #Python
Django ORM 查询管理器源码解析
Aug 05 #Python
python实现车牌识别的示例代码
Aug 05 #Python
使用python实现滑动验证码功能
Aug 05 #Python
Django 源码WSGI剖析过程详解
Aug 05 #Python
Python使用itchat 功能分析微信好友性别和位置
Aug 05 #Python
Python队列RabbitMQ 使用方法实例记录
Aug 05 #Python
You might like
2.PHP入门
2006/10/09 PHP
php语法检查的方法总结
2019/01/21 PHP
php设计模式之抽象工厂模式分析【星际争霸游戏案例】
2020/01/23 PHP
Javascript 的addEventListener()及attachEvent()区别分析
2009/05/21 Javascript
javascript 验证日期的函数
2010/03/18 Javascript
Extjs显示从数据库取出时间转换JSON后的出现问题
2012/11/20 Javascript
document.createElement()用法
2013/03/13 Javascript
Javascript页面添加到收藏夹的简单方法
2013/08/07 Javascript
使用js显示当前时间示例
2014/03/02 Javascript
基于javascript实现动态显示当前系统时间
2016/01/28 Javascript
Jquery中map函数的用法
2016/06/03 Javascript
微信小程序开发探究
2016/12/27 Javascript
原生js仿浏览器滚动条效果
2017/03/02 Javascript
node.js中EJS 模板快速入门教程
2017/05/08 Javascript
vue router-link传参以及参数的使用实例
2017/11/10 Javascript
jquery 输入框查找关键字并提亮颜色的实例代码
2018/01/23 jQuery
JavaScript中toLocaleString()和toString()的区别实例分析
2018/08/14 Javascript
详细讲解如何创建, 发布自己的 Vue UI 组件库
2019/05/29 Javascript
js实现无缝滚动双图切换效果
2019/07/09 Javascript
vue element-ui table组件动态生成表头和数据并修改单元格格式 父子组件通信
2019/08/15 Javascript
Vue Cli3 打包配置并自动忽略console.log语句的方法
2020/04/23 Javascript
在antd中setFieldsValue和defaultVal的用法
2020/10/29 Javascript
[38:23]2014 DOTA2国际邀请赛中国区预选赛 5 23 CIS VS LGD第一场
2014/05/24 DOTA
[01:01:35]Optic vs paiN 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
[58:35]OG vs EG 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.22
2019/09/05 DOTA
[36:16]完美世界DOTA2联赛PWL S3 access vs Rebirth 第一场 12.19
2020/12/24 DOTA
使用Python横向合并excel文件的实例
2018/12/11 Python
css3一个简易的 LED 数字时钟实现方法
2020/01/15 HTML / CSS
科沃斯机器人官网商城:Ecovacs
2016/08/29 全球购物
东南亚旅游平台:The Trip Guru
2018/01/01 全球购物
BLACKMORES澳洲官网:澳大利亚排名第一的保健品牌
2018/09/27 全球购物
俄语翻译实习生的自我评价分享
2013/11/06 职场文书
2013年研究生毕业感言
2014/02/06 职场文书
单位活动策划方案
2014/08/17 职场文书
小学语文复习计划
2015/01/19 职场文书
Python爬虫入门案例之回车桌面壁纸网美女图片采集
2021/10/16 Python