TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
Python操作json数据的一个简单例子
Apr 17 Python
用Python实现QQ游戏大家来找茬辅助工具
Sep 14 Python
Python中解析JSON并同时进行自定义编码处理实例
Feb 08 Python
Python中使用PIL库实现图片高斯模糊实例
Feb 08 Python
Python将xml和xsl转换为html的方法
Mar 10 Python
python简单实现计算过期时间的方法
Jun 09 Python
Python升级导致yum、pip报错的解决方法
Sep 06 Python
Python 比较文本相似性的方法(difflib,Levenshtein)
Oct 15 Python
python多个模块py文件的数据共享实例
Jan 11 Python
wxPython色环电阻计算器
Nov 18 Python
python__new__内置静态方法使用解析
Jan 07 Python
python 远程执行命令的详细代码
Feb 15 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
thinkPHP中钩子的使用方法实例分析
2017/11/16 PHP
浅谈PHP之ThinkPHP框架使用详解
2020/07/21 PHP
菜鸟javascript基础整理1
2010/12/06 Javascript
yepnope.js 异步加载资源文件
2011/09/08 Javascript
jQuery判断iframe中元素是否存在的方法
2013/05/11 Javascript
display和visibility的区别示例介绍
2014/02/26 Javascript
jQuery判断多个input file 都不能为空的例子
2015/06/23 Javascript
js实现文件上传表单域美化特效
2015/11/02 Javascript
js倒计时简单实现方法
2015/12/17 Javascript
基于jquery实现最简单的选项卡切换效果
2016/05/08 Javascript
微信小程序 二维码canvas绘制实例详解
2017/01/06 Javascript
node.js 发布订阅模式的实例
2017/09/10 Javascript
Angular6 用户自定义标签开发的实现方法
2019/01/08 Javascript
layui table去掉右侧滑动条的实现方法
2019/09/05 Javascript
记录微信小程序 height: calc(xx - xx);无效问题
2019/12/30 Javascript
JavaScript实现PC端横向轮播图
2020/02/07 Javascript
Python编程中装饰器的使用示例解析
2016/06/20 Python
python中实现k-means聚类算法详解
2017/11/11 Python
Python中的defaultdict与__missing__()使用介绍
2018/02/03 Python
Win10下python3.5和python2.7环境变量配置教程
2018/09/18 Python
pyqt5 tablewidget 利用线程动态刷新数据的方法
2019/06/17 Python
python中字符串数组逆序排列方法总结
2019/06/23 Python
html5基础教程常用技巧整理
2013/08/20 HTML / CSS
HTML5拖拽功能实现的拼图游戏
2018/07/31 HTML / CSS
俄罗斯旅游网站:Tripadvisor俄罗斯
2017/03/21 全球购物
佳能法国商店:Canon法国
2019/02/14 全球购物
Chemist Warehouse中文网:澳洲连锁大药房
2021/02/05 全球购物
简历中个人自我评价范文
2013/12/26 职场文书
计算机专业毕业生自荐信
2013/12/31 职场文书
小学生保护环境倡议书
2014/05/15 职场文书
实习护士自荐信
2014/06/21 职场文书
2015年办公室个人工作总结
2015/04/20 职场文书
涨价通知
2015/04/23 职场文书
撤诉申请怎么写
2015/05/19 职场文书
教师继续教育反思周记
2015/06/25 职场文书
Oracle使用别名的好处
2022/04/19 Oracle