基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python多线程Selenium跨浏览器测试
Apr 01 Python
详解tensorflow训练自己的数据集实现CNN图像分类
Feb 07 Python
Python使用numpy实现BP神经网络
Mar 10 Python
python实现ID3决策树算法
Aug 29 Python
Django使用unittest模块进行单元测试过程解析
Aug 02 Python
Python 利用高德地图api实现经纬度与地址的批量转换
Aug 14 Python
基于Python脚本实现邮件报警功能
May 20 Python
Numpy中np.random.rand()和np.random.randn() 用法和区别详解
Oct 23 Python
python正则表达式re.match()匹配多个字符方法的实现
Jan 27 Python
python实现代码审查自动回复消息
Feb 01 Python
Python环境搭建过程从安装到Hello World
Feb 05 Python
Python&Matlab实现樱花的绘制
Apr 07 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
PHP安装GeoIP扩展根据IP获取地理位置及计算距离的方法
2016/07/01 PHP
php如何执行非缓冲查询API
2016/07/22 PHP
PHP面向对象自动加载机制原理与用法分析
2016/10/14 PHP
解决FireFox下[使用event很麻烦]的问题
2006/11/26 Javascript
javascript中使用replaceAll()函数实现字符替换的方法
2010/12/25 Javascript
图片无缝滚动代码(向左/向下/向上)
2013/04/10 Javascript
jQuery大于号(>)选择器的作用解释
2015/01/13 Javascript
jQuery实现仿百度帖吧头部固定导航效果
2015/08/07 Javascript
JS实现兼容性好,带缓冲的动感网页右键菜单效果
2015/09/18 Javascript
JS简单实现无缝滚动效果实例
2016/08/24 Javascript
jQuery延迟执行的实现方法
2016/12/21 Javascript
使用vue-cli+webpack搭建vue开发环境的方法
2017/12/22 Javascript
vue 开发一个按钮组件的示例代码
2018/03/27 Javascript
JavaScript日期工具类DateUtils定义与用法示例
2018/09/03 Javascript
vue2.x集成百度UEditor富文本编辑器的方法
2018/09/21 Javascript
js 实现ajax发送步骤过程详解
2019/07/25 Javascript
Element Card 卡片的具体使用
2020/07/26 Javascript
Python中不同进制的语法及转换方法分析
2016/07/27 Python
python 动态调用函数实例解析
2019/10/21 Python
详解用Python进行时间序列预测的7种方法
2020/03/13 Python
Python使用pdb调试代码的技巧
2020/05/03 Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
2020/06/10 Python
python能做哪方面的工作
2020/06/15 Python
Visual Studio Code搭建django项目的方法步骤
2020/09/17 Python
英国经济型酒店品牌:Travelodge
2019/12/17 全球购物
优秀小学生家长评语
2014/01/30 职场文书
3.12植树节活动总结2014
2014/03/13 职场文书
运动会方阵口号
2014/06/07 职场文书
办公室文员岗位职责范本
2014/06/12 职场文书
银行主办会计岗位职责
2014/08/13 职场文书
公务员爱岗敬业演讲稿
2014/08/26 职场文书
2014最新实习证明模板
2014/10/02 职场文书
2015年党员自评材料
2014/12/17 职场文书
戒赌保证书
2015/05/11 职场文书
2015年清剿火患专项行动工作总结
2015/07/27 职场文书
详解Redis基本命令与使用场景
2021/06/01 Redis