TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
Python向日志输出中添加上下文信息
May 24 Python
Python实现网站注册验证码生成类
Jun 08 Python
学习python中matplotlib绘图设置坐标轴刻度、文本
Feb 07 Python
儿童学习python的一些小技巧
May 27 Python
python实现输入数字的连续加减方法
Jun 22 Python
深度辨析Python的eval()与exec()的方法
Mar 26 Python
Python中py文件转换成exe可执行文件的方法
Jun 14 Python
python中property属性的介绍及其应用详解
Aug 29 Python
flask框架自定义url转换器操作详解
Jan 25 Python
python中提高pip install速度
Feb 14 Python
基于CentOS搭建Python Django环境过程解析
Aug 24 Python
python爬取2021猫眼票房字体加密实例
Feb 19 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
用PHP动态创建Flash动画
2006/10/09 PHP
PHP Header用于页面跳转要注意的几个问题总结
2008/10/03 PHP
利用php获取服务器时间的实现代码
2013/06/07 PHP
深入解析WordPress中加载模板的get_template_part函数
2016/01/11 PHP
用javascript做拖动布局的思路
2008/05/31 Javascript
javascript(jquery)利用函数修改全局变量的代码
2009/11/02 Javascript
isArray()函数(JavaScript中对象类型判断的几种方法)
2009/11/26 Javascript
JavaScript 大数据相加的问题
2011/08/03 Javascript
解决js正则匹配换行问题实现代码
2012/12/10 Javascript
onbeforeunload与onunload事件异同点总结
2013/06/24 Javascript
javascript中Number对象的toString()方法分析
2014/12/20 Javascript
jQuery实现个性翻牌效果导航菜单的方法
2015/03/09 Javascript
JavaScript 变量、作用域及内存
2015/04/08 Javascript
基于BootStrap Metronic开发框架经验小结【二】列表分页处理和插件JSTree的使用
2016/05/12 Javascript
详解webpack进阶之loader篇
2017/08/23 Javascript
Angularjs按需查询实例代码
2017/10/30 Javascript
Layer组件多个iframe弹出层打开与关闭及参数传递的方法
2019/09/25 Javascript
详解uniapp的全局变量实现方式
2021/01/11 Javascript
django之session与分页(实例讲解)
2017/11/13 Python
《Python学习手册》学习总结
2018/01/17 Python
对Python 网络设备巡检脚本的实例讲解
2018/04/22 Python
HTML5 Canvas中使用用路径描画圆弧
2015/01/01 HTML / CSS
床上用品全球在线购物:BeddingInn
2016/12/18 全球购物
Shopee马来西亚:随拍即卖,最佳行动电商拍卖平台
2017/06/05 全球购物
ghd官网:英国ghd直发器品牌
2018/05/04 全球购物
计算机网络毕业生自荐信
2013/10/01 职场文书
教师师德反思材料
2014/02/15 职场文书
出纳员岗位职责
2014/03/13 职场文书
《画风》教学反思
2014/04/16 职场文书
献爱心捐款倡议书
2014/05/14 职场文书
国际残疾人日广播稿范文
2014/10/09 职场文书
基层党支部整改方案
2014/10/25 职场文书
公司慰问信范文
2015/03/23 职场文书
2015年药店工作总结
2015/04/20 职场文书
2019年描写人生经典诗句大全
2019/07/08 职场文书
《月歌。》宣布制作10周年纪念剧场版《RABBITS KINGDOM THE MOVIE》
2022/04/02 日漫