TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
python实现ftp客户端示例分享
Feb 17 Python
Python ORM框架SQLAlchemy学习笔记之映射类使用实例和Session会话介绍
Jun 10 Python
python中使用urllib2获取http请求状态码的代码例子
Jul 07 Python
Python中lambda的用法及其与def的区别解析
Jul 28 Python
Python数据类型之Dict字典实例详解
May 07 Python
使用pandas读取文件的实现
Jul 31 Python
python Qt5实现窗体跟踪鼠标移动
Dec 13 Python
python——全排列数的生成方式
Feb 26 Python
tensorflow指定CPU与GPU运算的方法实现
Apr 21 Python
关于Keras Dense层整理
May 21 Python
如何使用PyCharm引入需要使用的包的方法
Sep 22 Python
Python开发简易五子棋小游戏
May 02 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
PHP在XP下IIS和Apache2服务器上的安装
2006/09/05 PHP
ajax在joomla中的原生态应用代码
2012/07/19 PHP
2014年10个最佳的PHP图像操作库
2014/07/14 PHP
php过滤html标记属性类用法实例
2014/09/23 PHP
php开发中的页面跳转方法总结
2015/04/26 PHP
php+MySQL实现登录时验证登录名和密码是否正确
2016/05/10 PHP
php简单的上传类分享
2016/05/15 PHP
php魔法函数与魔法常量使用介绍
2017/07/23 PHP
用javascript编写的第一人称射击游戏
2007/02/25 Javascript
google地图的路线实现代码
2009/08/20 Javascript
JS自定义功能函数实现动态添加网址参数修改网址参数值
2013/08/02 Javascript
javascript的创建多行字符串的7种方法
2014/04/29 Javascript
深入剖析JavaScript:Object类型
2016/05/10 Javascript
jQuery使用$获取对象后检查该对象是否存在的实现方法
2016/09/04 Javascript
vue使用laydate时间插件的方法
2018/11/14 Javascript
element-ui树形控件后台返回的数据+生成组织树的工具类
2020/03/05 Javascript
Python Queue模块详细介绍及实例
2016/12/27 Python
Python获取指定文件夹下的文件名的方法
2018/02/06 Python
Python常见内置高效率函数用法示例
2018/07/31 Python
Linux下Python安装完成后使用pip命令的详细教程
2018/11/22 Python
python爬取酷狗音乐排行榜
2019/02/20 Python
Django框架自定义模型管理器与元选项用法分析
2019/07/22 Python
python3 map函数和filter函数详解
2019/08/26 Python
Python Gitlab Api 使用方法
2019/08/28 Python
Python要求O(n)复杂度求无序列表中第K的大元素实例
2020/04/02 Python
cookies应对python反爬虫知识点详解
2020/11/25 Python
Wilson体育用品官网:美国著名运动器材品牌
2019/05/12 全球购物
园林施工员岗位职责
2013/12/11 职场文书
应届本科生推荐信范文
2013/12/25 职场文书
乡镇干部十八大感言
2014/02/17 职场文书
触电现场处置方案
2014/05/14 职场文书
企业宣传标语
2014/06/09 职场文书
竞选大队干部演讲稿
2014/09/11 职场文书
民间借贷纠纷起诉书
2015/08/03 职场文书
《红领巾真好》教学反思
2016/02/16 职场文书
Python使用OpenCV实现虚拟缩放效果
2022/02/28 Python