TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
Python使用cx_Oracle模块将oracle中数据导出到csv文件的方法
May 16 Python
python 实时遍历日志文件
Apr 12 Python
在CentOS上配置Nginx+Gunicorn+Python+Flask环境的教程
Jun 07 Python
定制FileField中的上传文件名称实例
Aug 23 Python
python自制包并用pip免提交到pypi仅安装到本机【推荐】
Jun 03 Python
浅析Python与Mongodb数据库之间的操作方法
Jul 01 Python
Python中的单下划线和双下划线使用场景详解
Sep 09 Python
python3中利用filter函数输出小于某个数的所有回文数实例
Nov 24 Python
python用pip install时安装失败的一系列问题及解决方法
Feb 24 Python
Python tkinter布局与按钮间距设置方式
Mar 04 Python
在django admin详情表单显示中添加自定义控件的实现
Mar 11 Python
Python爬虫:从m3u8文件里提取小视频的正确操作
May 14 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
php GD绘制24小时柱状图
2008/06/28 PHP
php仿ZOL分页类代码
2008/10/02 PHP
php实现建立多层级目录的方法
2014/07/19 PHP
Yii框架中用response保存cookie,用request读取cookie的原理解析
2019/09/04 PHP
jQuery EasyUI API 中文文档 - MenuButton菜单按钮使用介绍
2011/10/06 Javascript
jQuery+CSS 半开折叠效果原理及代码(自写)
2013/03/04 Javascript
js 事件截取enter按键页面提交事件示例代码
2014/03/04 Javascript
使用jQuery不判断浏览器高度解决iframe自适应高度问题
2014/12/16 Javascript
js对象继承之原型链继承实例
2015/01/10 Javascript
Vue非父子组件通信详解
2017/06/12 Javascript
自定义vue全局组件use使用、vuex的使用详解
2017/06/14 Javascript
jQuery EasyUI 选项卡面板tabs的使用实例讲解
2017/12/25 jQuery
Three.js 再探 - 写一个微信跳一跳极简版游戏
2018/01/04 Javascript
vue轮播图插件vue-concise-slider的使用
2018/03/13 Javascript
使用VUE+iView+.Net Core上传图片的方法示例
2019/01/04 Javascript
vue的滚动条插件实现代码
2019/09/07 Javascript
jquery实现吸顶导航效果
2020/01/08 jQuery
jQuery开发仿QQ版音乐播放器
2020/07/10 jQuery
Python实现子类调用父类的方法
2014/11/10 Python
python实现将英文单词表示的数字转换成阿拉伯数字的方法
2015/07/02 Python
安装ElasticSearch搜索工具并配置Python驱动的方法
2015/12/22 Python
pycharm远程调试openstack的图文教程
2017/11/21 Python
Python+selenium 获取浏览器窗口坐标、句柄的方法
2018/10/14 Python
在python中利用最小二乘拟合二次抛物线函数的方法
2018/12/29 Python
Python当中的array数组对象实例详解
2019/06/12 Python
对Python函数设计规范详解
2019/07/19 Python
Python识别html主要文本框过程解析
2020/02/18 Python
澳大利亚快时尚鞋类市场:Billini
2018/05/20 全球购物
统计岗位职责
2014/02/21 职场文书
公司节能减排倡议书
2014/05/14 职场文书
学生自我鉴定格式及范文
2014/09/16 职场文书
群众路线自我剖析范文
2014/11/04 职场文书
运动会新闻稿
2015/07/17 职场文书
MongoDB数据库常用的10条操作命令
2021/06/18 MongoDB
MySQL系列之十三 MySQL的复制
2021/07/02 MySQL
Python创建SQL数据库流程逐步讲解
2022/09/23 Python