TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
linux环境下安装pyramid和新建项目的步骤
Nov 27 Python
跟老齐学Python之大话题小函数(1)
Oct 10 Python
使用Python的Tornado框架实现一个简单的WebQQ机器人
Apr 24 Python
Python使用requests发送POST请求实例代码
Jan 25 Python
python 日志增量抓取实现方法
Apr 28 Python
python 美化输出信息的实例
Oct 15 Python
python实现名片管理系统
Nov 29 Python
将python文件打包exe独立运行程序方法详解
Feb 12 Python
Python多线程获取返回值代码实例
Feb 17 Python
利用python绘制中国地图(含省界、河流等)
Sep 21 Python
python 提高开发效率的5个小技巧
Oct 19 Python
python可视化 matplotlib画图使用colorbar工具自定义颜色
Dec 07 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
php数组函数序列之array_combine() - 数组合并函数使用说明
2011/10/29 PHP
php数组转换js数组操作及json_encode的用法详解
2013/10/26 PHP
PHP上传图片进行等比缩放可增加水印功能
2014/01/13 PHP
基于linnux+phantomjs实现生成图片格式的网页快照
2015/04/15 PHP
PHP实现简单实用的验证码类
2015/07/29 PHP
Zend Framework实现将session存储在memcache中的方法
2016/03/22 PHP
PHP文件操作简单介绍及函数汇总
2020/12/11 PHP
Track Image Loading效果代码分析
2007/08/13 Javascript
JavaScript 验证浏览器是否支持javascript的方法小结
2009/05/17 Javascript
学习ExtJS table布局
2009/10/08 Javascript
JS异常处理的一个想法(sofish)
2013/03/14 Javascript
使用javascript做的一个随机点名程序
2014/02/13 Javascript
jquery使用ajax实现微信自动回复插件
2014/04/28 Javascript
Javascript中的delete操作符详细介绍
2014/06/06 Javascript
jquery通过closest选择器修改上级元素的方法
2015/03/17 Javascript
全面解析Bootstrap排版使用方法(标题)
2015/11/30 Javascript
BootStrap 超链接变按钮的实现方法
2016/09/25 Javascript
jquery日历插件e-calendar升级版
2016/11/10 Javascript
javascript表达式和运算符详解
2017/02/07 Javascript
JavaScript 函数的定义-调用、注意事项
2017/04/16 Javascript
JS中的算法与数据结构之列表(List)实例详解
2019/08/16 Javascript
微信小程序进入广告实现代码实例
2019/09/19 Javascript
javascript中的with语句学习笔记及用法
2020/02/17 Javascript
[53:38]OG vs LGD 2018国际邀请赛淘汰赛BO3 第三场 8.26
2018/08/30 DOTA
详解Python最长公共子串和最长公共子序列的实现
2018/07/07 Python
Pytorch 卷积中的 Input Shape用法
2020/06/29 Python
Pythonic版二分查找实现过程原理解析
2020/08/11 Python
详解python使用金山词霸的翻译功能(调试工具断点的使用)
2021/01/07 Python
Python3自带工具2to3.py 转换 Python2.x 代码到Python3的操作
2021/03/03 Python
俄罗斯披萨、寿司和面食送货到家服务:2 Берега
2019/12/15 全球购物
小学后勤管理制度
2014/01/14 职场文书
2014年两会学习心得范例
2014/03/17 职场文书
大学竞选班干部演讲稿
2014/08/21 职场文书
第28个世界无烟日活动总结
2015/02/10 职场文书
2015年纪委工作总结
2015/05/13 职场文书
人生遥控器观后感
2015/06/11 职场文书