TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
Python操作Mysql实例代码教程在线版(查询手册)
Feb 18 Python
python模拟登录百度贴吧(百度贴吧登录)实例
Dec 18 Python
关于Django显示时间你应该知道的一些问题
Dec 25 Python
深入浅析Python2.x和3.x版本的主要区别
Nov 30 Python
Python解析json时提示“string indices must be integers”问题解决方法
Jul 31 Python
Python文本处理简单易懂方法解析
Dec 19 Python
Python实现ATM系统
Feb 17 Python
Python可以实现栈的结构吗
May 27 Python
安装pyinstaller遇到的各种问题(小结)
Nov 20 Python
cookies应对python反爬虫知识点详解
Nov 25 Python
使用Python获取字典键对应值的方法
Apr 26 Python
python数字图像处理实现图像的形变与缩放
Jun 28 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
PHP响应post请求上传文件的方法
2015/12/17 PHP
PHP简单实现无限级分类的方法
2016/05/13 PHP
老生常谈PHP中的数据结构:DS扩展
2017/07/17 PHP
PHP微信H5支付开发实例
2018/07/25 PHP
PHP下用Swoole实现Actor并发模型的方法
2019/06/12 PHP
JavaScript XML操作 封装类
2009/07/01 Javascript
js DataSet数据源处理代码
2010/03/29 Javascript
js ondocumentready onmouseover onclick onmouseout 样式
2010/07/22 Javascript
javascript生成随机大小写字母的方法
2014/02/20 Javascript
JavaScript数组方法大全(推荐)
2016/07/05 Javascript
jQuery.Validate表单验证插件的使用示例详解
2017/01/04 Javascript
vue使用watch 观察路由变化,重新获取内容
2017/03/08 Javascript
Webpack常见静态资源处理-模块加载器(Loaders)+ExtractTextPlugin插件
2017/06/29 Javascript
基于原生js实现判断元素是否有指定class名
2020/07/11 Javascript
[52:06]完美世界DOTA2联赛决赛日 Inki vs LBZS 第一场 11.08
2020/11/10 DOTA
Python 调用DLL操作抄表机
2009/01/12 Python
用python登录Dr.com思路以及代码分享
2014/06/25 Python
Python 实现文件的全备份和差异备份详解
2016/12/27 Python
开源软件包和环境管理系统Anaconda的安装使用
2017/09/04 Python
Python内置函数reversed()用法分析
2018/03/20 Python
解决python升级引起的pip执行错误的问题
2018/06/12 Python
在IPython中进行Python程序执行时间的测量方法
2018/11/01 Python
Django框架搭建的简易图书信息网站案例
2019/05/25 Python
django admin后管定制-显示字段的实例
2020/03/11 Python
Window10上Tensorflow的安装(CPU和GPU版本)
2020/12/15 Python
canvas学习和滤镜实现代码
2018/08/22 HTML / CSS
HTML5实现移动端点击翻牌功能
2020/10/23 HTML / CSS
英文自荐信
2013/12/19 职场文书
优秀学生干部推荐材料
2014/02/03 职场文书
应届大专生求职信
2014/06/26 职场文书
分居协议书范本
2014/11/03 职场文书
学校端午节活动总结
2015/02/11 职场文书
2015年行政工作总结范文
2015/04/09 职场文书
nginx处理http请求实现过程解析
2021/03/31 Servers
MySQL系列之二 多实例配置
2021/07/02 MySQL
HTML基础详解(上)
2021/10/16 HTML / CSS