TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
Python __setattr__、 __getattr__、 __delattr__、__call__用法示例
Mar 06 Python
python字符串的常用操作方法小结
May 21 Python
Python实现生成随机数据插入mysql数据库的方法
Dec 25 Python
python+tkinter编写电脑桌面放大镜程序实例代码
Jan 16 Python
python+pyqt5实现KFC点餐收银系统
Jan 24 Python
Python自动化完成tb喵币任务的操作方法
Oct 30 Python
python GUI库图形界面开发之PyQt5信号与槽机制、自定义信号基础介绍
Feb 25 Python
Keras中 ImageDataGenerator函数的参数用法
Jul 03 Python
对python中list的五种查找方法说明
Jul 13 Python
Python学习笔记之装饰器
Aug 06 Python
python中pathlib模块的基本用法与总结
Aug 17 Python
pyqt5打包成exe可执行文件的方法
May 14 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
php实例分享之html转为rtf格式
2014/06/02 PHP
PHP中把对象数组转换成普通数组的方法
2015/07/10 PHP
Thinkphp无限级分类代码
2015/11/11 PHP
jquery 学习之二 属性(类)
2010/11/25 Javascript
js动态删除div元素基本思路及实现代码
2014/05/08 Javascript
jQuery实现菜单式图片滑动切换
2015/03/14 Javascript
jQuery实现html元素拖拽
2015/07/21 Javascript
jQuery数据类型小结(14个)
2016/01/08 Javascript
JS实现获取剪贴板内容的方法
2016/06/21 Javascript
微信小程序实现YDUI的ScrollNav组件
2018/02/02 Javascript
详解如何在Angular优雅编写HTTP请求
2018/12/05 Javascript
[02:21]十步杀一人,千里不留行——DOTA2全新英雄天涯墨客展示
2018/08/29 DOTA
[00:37]食人魔魔法师轮盘吉兆顺应全新至宝将拥有额外款式
2019/12/19 DOTA
Python中的进程分支fork和exec详解
2015/04/11 Python
将Emacs打造成强大的Python代码编辑工具
2015/11/20 Python
python中PIL安装简单教程
2016/04/21 Python
Python中的sort()方法使用基础教程
2017/01/08 Python
python使用__slots__让你的代码更加节省内存
2018/09/05 Python
Python实现的服务器示例小结【单进程、多进程、多线程、非阻塞式】
2019/05/23 Python
NumPy排序的实现
2020/01/21 Python
深入浅析python的第三方库pandas
2020/02/13 Python
使用python3 实现插入数据到mysql
2020/03/02 Python
解决keras backend 越跑越慢问题
2020/06/18 Python
学python需要去培训机构吗
2020/07/01 Python
聊聊python中的异常嵌套
2020/09/01 Python
python 如何实现遗传算法
2020/09/22 Python
顶丰TOPPIK台湾官网:增发纤维假发,告别秃发困扰
2018/06/13 全球购物
美国在线购物频道:Shop LC
2019/04/21 全球购物
阿里巴巴英国:Alibaba英国
2019/12/11 全球购物
总经理办公室主任岗位职责
2013/11/12 职场文书
人事行政主管岗位职责
2013/12/22 职场文书
导游词开场白
2015/01/31 职场文书
餐饮服务员岗位职责
2015/02/09 职场文书
如何利用JavaScript实现二叉搜索树
2021/04/02 Javascript
Linux磁盘管理方法介绍
2022/06/01 Servers
前端传参数进行Mybatis调用mysql存储过程执行返回值详解
2022/08/14 MySQL