python MNIST手写识别数据调用API的方法


Posted in Python onAugust 08, 2018

MNIST数据集比较小,一般入门机器学习都会采用这个数据集来训练

下载地址:yann.lecun.com/exdb/mnist/

有4个有用的文件:
train-images-idx3-ubyte: training set images
train-labels-idx1-ubyte: training set labels
t10k-images-idx3-ubyte: test set images
t10k-labels-idx1-ubyte: test set labels

The training set contains 60000 examples, and the test set 10000 examples. 数据集存储是用binary file存储的,黑白图片。

下面给出load数据集的代码:

import os
import struct
import numpy as np
import matplotlib.pyplot as plt

def load_mnist():
  '''
  Load mnist data
  http://yann.lecun.com/exdb/mnist/

  60000 training examples
  10000 test sets

  Arguments:
    kind: 'train' or 'test', string charater input with a default value 'train'

  Return:
    xxx_images: n*m array, n is the sample count, m is the feature number which is 28*28
    xxx_labels: class labels for each image, (0-9)
  '''

  root_path = '/home/cc/deep_learning/data_sets/mnist'

  train_labels_path = os.path.join(root_path, 'train-labels.idx1-ubyte')
  train_images_path = os.path.join(root_path, 'train-images.idx3-ubyte')

  test_labels_path = os.path.join(root_path, 't10k-labels.idx1-ubyte')
  test_images_path = os.path.join(root_path, 't10k-images.idx3-ubyte')

  with open(train_labels_path, 'rb') as lpath:
    # '>' denotes bigedian
    # 'I' denotes unsigned char
    magic, n = struct.unpack('>II', lpath.read(8))
    #loaded = np.fromfile(lpath, dtype = np.uint8)
    train_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float)

  with open(train_images_path, 'rb') as ipath:
    magic, num, rows, cols = struct.unpack('>IIII', ipath.read(16))
    loaded = np.fromfile(train_images_path, dtype = np.uint8)
    # images start from the 16th bytes
    train_images = loaded[16:].reshape(len(train_labels), 784).astype(np.float)

  with open(test_labels_path, 'rb') as lpath:
    # '>' denotes bigedian
    # 'I' denotes unsigned char
    magic, n = struct.unpack('>II', lpath.read(8))
    #loaded = np.fromfile(lpath, dtype = np.uint8)
    test_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float)

  with open(test_images_path, 'rb') as ipath:
    magic, num, rows, cols = struct.unpack('>IIII', ipath.read(16))
    loaded = np.fromfile(test_images_path, dtype = np.uint8)
    # images start from the 16th bytes
    test_images = loaded[16:].reshape(len(test_labels), 784)  

  return train_images, train_labels, test_images, test_labels

再看看图片集是什么样的:

def test_mnist_data():
  '''
  Just to check the data

  Argument:
    none

  Return:
    none
  '''
  train_images, train_labels, test_images, test_labels = load_mnist()
  fig, ax = plt.subplots(nrows = 2, ncols = 5, sharex = True, sharey = True)
  ax =ax.flatten()
  for i in range(10):
    img = train_images[i][:].reshape(28, 28)
    ax[i].imshow(img, cmap = 'Greys', interpolation = 'nearest')
    print('corresponding labels = %d' %train_labels[i])

if __name__ == '__main__':
  test_mnist_data()

跑出的结果如下:

python MNIST手写识别数据调用API的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入讲解Java编程中类的生命周期
Feb 05 Python
老生常谈Python序列化和反序列化
Jun 28 Python
Python 使用PIL中的resize进行缩放的实例讲解
Aug 03 Python
Django 实现前端图片压缩功能的方法
Aug 07 Python
Python类中的魔法方法之 __slots__原理解析
Aug 26 Python
python django生成迁移文件的实例
Aug 31 Python
python已协程方式处理任务实现过程
Dec 27 Python
Python中if有多个条件处理方法
Feb 26 Python
浅谈keras使用预训练模型vgg16分类,损失和准确度不变
Jul 02 Python
Python 整行读取文本方法并去掉readlines换行\n操作
Sep 03 Python
python爬虫scrapy框架的梨视频案例解析
Feb 20 Python
Python中Numpy和Matplotlib的基本使用指南
Nov 02 Python
python实现屏保计时器的示例代码
Aug 08 #Python
详解Python 装饰器执行顺序迷思
Aug 08 #Python
python Flask 装饰器顺序问题解决
Aug 08 #Python
Python BS4库的安装与使用详解
Aug 08 #Python
python特性语法之遍历、公共方法、引用
Aug 08 #Python
用Python shell简化开发
Aug 08 #Python
在Python中使用gRPC的方法示例
Aug 08 #Python
You might like
zend optimizer在wamp的基础上安装图文教程
2013/10/26 PHP
php5.3 goto函数介绍和示例
2014/03/21 PHP
完美的2个php检测字符串是否是utf-8编码函数分享
2014/07/28 PHP
PHP上传文件参考配置大文件上传
2015/12/16 PHP
PHP图像识别技术原理与实现
2016/10/27 PHP
php 生成加密公钥加密私钥实例详解
2017/06/16 PHP
PHP实现的多维数组排序算法分析
2018/02/10 PHP
php使用fullcalendar日历插件详解
2019/03/06 PHP
Laravel 中创建 Zip 压缩文件并提供下载的实现方法
2019/04/02 PHP
利用jq让你的div居中的好方法分享
2013/11/21 Javascript
jquery实现保存已选用户
2014/07/21 Javascript
jquery pagination插件动态分页实例(Bootstrap分页)
2016/12/23 Javascript
js记录点击某个按钮的次数-刷新次数为初始状态的实例
2017/02/15 Javascript
javascript定时器取消定时器及优化方法
2017/07/08 Javascript
vue滚动轴插件better-scroll使用详解
2017/10/17 Javascript
Angular实现点击按钮后在上方显示输入内容的方法
2017/12/27 Javascript
vue中post请求以a=a&b=b 的格式写遇到的问题
2018/04/27 Javascript
Angular5集成eventbus的示例代码
2018/07/19 Javascript
解决Vue+Element ui开发中碰到的IE问题
2018/09/03 Javascript
Python 字符串操作实现代码(截取/替换/查找/分割)
2013/06/08 Python
Python中import机制详解
2017/11/14 Python
python SSH模块登录,远程机执行shell命令实例解析
2018/01/12 Python
python实现zabbix发送短信脚本
2018/09/17 Python
浅谈Python中threading join和setDaemon用法及区别说明
2020/05/02 Python
浅谈优化Django ORM中的性能问题
2020/07/09 Python
利用CSS3的特性改变文本选中时的颜色
2013/09/11 HTML / CSS
印尼太阳百货公司网站:Matahari
2018/02/04 全球购物
bonprix匈牙利:女士、男士和儿童服装
2019/07/19 全球购物
美津浓巴西官方网站:Mizuno巴西
2019/07/24 全球购物
Java面试题及答案
2012/09/08 面试题
输入一行文字,找出其中大写字母、小写字母、空格、数字、及其他字符各有多少
2016/04/15 面试题
在网络中有两台主机A和B,并通过路由器和其他交换设备连接起来,已经确认物理连接正确无误,怎么来测试这两台机器是否连通?如果不通,怎么来判断故障点?怎么排
2014/01/13 面试题
护士自我评价
2014/02/01 职场文书
2014学习优秀共产党员先进事迹材料思想汇报
2014/09/14 职场文书
2016秋季幼儿园开学寄语
2015/12/03 职场文书
作文之亲情600字
2019/09/23 职场文书