基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Python中is和id的用法
Apr 03 Python
Python基于回溯法子集树模板解决野人与传教士问题示例
Sep 11 Python
python中reduce()函数的使用方法示例
Sep 29 Python
Python实现读取字符串按列分配后按行输出示例
Apr 17 Python
pandas数据处理基础之筛选指定行或者指定列的数据
May 03 Python
Python使用matplotlib和pandas实现的画图操作【经典示例】
Jun 13 Python
python3 使用openpyxl将mysql数据写入xlsx的操作
May 15 Python
如何快速理解python的垃圾回收机制
Sep 01 Python
Pytest allure 命令行参数的使用
Apr 18 Python
详解分布式系统中如何用python实现Paxos
May 18 Python
pytorch 实现在测试的时候启用dropout
May 27 Python
使用Python获取字典键对应值的方法
Apr 26 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
一个很方便的 XML 类!!原创的噢
2006/10/09 PHP
实用函数2
2007/11/08 PHP
php simplexmlElement操作xml的命名空间实现代码
2011/01/04 PHP
PHP定时执行计划任务的多种方法小结
2011/12/19 PHP
PHP随机字符串生成代码(包括大小写字母)
2013/06/24 PHP
PHP写的资源下载防盗链类分享
2014/05/12 PHP
PHP概率计算函数汇总
2015/09/13 PHP
Thinkphp微信公众号支付接口
2016/08/04 PHP
PHP单例模式简单用法示例
2017/06/23 PHP
javascript 字符串连接的性能问题(多浏览器)
2008/11/18 Javascript
javascript下4个跨浏览器必备的函数
2010/03/07 Javascript
JS模板实现方法
2013/04/03 Javascript
jquery实现文字由下到上循环滚动的实例代码
2013/08/09 Javascript
Ajax局部更新导致JS事件重复触发问题的解决方法
2014/10/14 Javascript
angularJS 中$scope方法使用指南
2015/02/09 Javascript
对layer弹出框中icon数字参数的说明介绍
2019/09/04 Javascript
javascript操作元素的常见方法小结
2019/11/13 Javascript
如何在vue 中使用柱状图 并自修改配置
2021/01/21 Vue.js
在Python的列表中利用remove()方法删除元素的教程
2015/05/21 Python
python 实现A*算法的示例代码
2018/08/13 Python
YUV转为jpg图像的实现
2019/12/09 Python
Python3+selenium配置常见报错解决方案
2020/08/28 Python
Python内置函数property()如何使用
2020/09/01 Python
Matlab使用Plot函数实现数据动态显示方法总结
2021/02/25 Python
基于HTML5 的人脸识别活体认证的实现方法
2016/06/22 HTML / CSS
Calphalon美国官网:美国顶级锅具品牌
2020/02/05 全球购物
高三生物教学反思
2014/01/25 职场文书
单位工程竣工验收方案
2014/03/16 职场文书
幼儿园大班毕业教师寄语
2014/04/03 职场文书
绿色家庭事迹材料
2014/05/01 职场文书
考生诚信考试承诺书
2014/05/23 职场文书
竞选学委演讲稿
2014/09/13 职场文书
现实表现材料范文
2014/12/23 职场文书
2015年团支书工作总结
2015/04/03 职场文书
python 遍历磁盘目录的三种方法
2021/04/02 Python
MySQL的存储过程和相关函数
2022/04/26 MySQL