TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
python实现代理服务功能实例
Nov 15 Python
python格式化字符串实例总结
Sep 28 Python
Python基于高斯消元法计算线性方程组示例
Jan 17 Python
python实现根据文件关键字进行切分为多个文件的示例
Dec 10 Python
利用 Flask 动态展示 Pyecharts 图表数据方法小结
Sep 04 Python
Python unittest 自动识别并执行测试用例方式
Mar 09 Python
python使用梯度下降算法实现一个多线性回归
Mar 24 Python
关于python tushare Tkinter构建的简单股票可视化查询系统(Beta v0.13)
Oct 19 Python
python时间time模块处理大全
Oct 25 Python
Django跨域请求原理及实现代码
Nov 14 Python
python3 os进行嵌套操作的实例讲解
Nov 19 Python
详解Open Folder as PyCharm Project怎么添加的方法
Dec 29 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
PHP中如何防止外部恶意提交调用ajax接口
2016/04/11 PHP
PHP支付系统设计与典型案例分享
2016/08/02 PHP
javascript 浏览器判断 绑定事件 arguments 转换数组 数组遍历
2009/07/06 Javascript
IE6图片加载的一个BUG解决方法
2010/07/13 Javascript
jqgrid 简单学习笔记
2011/05/03 Javascript
读jQuery之六 缓存数据功能介绍
2011/06/21 Javascript
jquery 表单验证之通过 class验证表单不为空
2015/11/02 Javascript
[原创]jQuery常用的4种加载方式分析
2016/07/25 Javascript
javascript之with的使用(阿里云、淘宝使用代码分析)
2016/10/11 Javascript
Javascript使用SWFUpload进行多文件上传
2016/11/16 Javascript
基于模板引擎Jade的应用(详解)
2017/12/12 Javascript
详解微信小程序审核不通过的解决方法
2018/01/17 Javascript
基于cropper.js封装vue实现在线图片裁剪组件功能
2018/03/01 Javascript
bootstrap select2插件用ajax来获取和显示数据的实例
2018/08/09 Javascript
VeeValidate 的使用场景以及配置详解
2019/01/11 Javascript
vue实现页面滚动到底部刷新
2019/08/16 Javascript
vue 实现cli3.0中使用proxy进行代理转发
2019/10/30 Javascript
js实现类选择器和name属性选择器的示例步骤
2021/02/07 Javascript
[10:39]DOTA2上海特级锦标赛音乐会纪录片
2016/03/21 DOTA
[34:10]Secret vs VG 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.24
2019/09/10 DOTA
Python中用Ctrl+C终止多线程程序的问题解决
2013/03/30 Python
儿童python练习实例
2018/05/27 Python
python实现在cmd窗口显示彩色文字
2019/06/24 Python
flask 实现token机制的示例代码
2019/11/07 Python
Python中的 ansible 动态Inventory 脚本
2020/01/19 Python
详解python常用命令行选项与环境变量
2020/02/20 Python
Smashbox官网:美国知名彩妆品牌
2017/01/05 全球购物
POP文化和音乐灵感的时尚:Hot Topic
2019/06/19 全球购物
请编程遍历页面上所有 TextBox 控件并给它赋值为 string.Empty
2015/12/03 面试题
优秀高中生事迹材料
2014/02/11 职场文书
寒假家长评语大全
2014/04/16 职场文书
一般基层干部群众路线教育实践活动个人对照检查材料
2014/11/04 职场文书
物业客服专员岗位职责
2015/04/07 职场文书
惹女朋友生气检讨书
2015/05/06 职场文书
公证书
2019/04/17 职场文书
pytorch交叉熵损失函数的weight参数的使用
2021/05/24 Python