基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python快速查找算法应用实例
Sep 26 Python
Python实现网络端口转发和重定向的方法
Sep 19 Python
Python常见加密模块用法分析【MD5,sha,crypt模块】
May 24 Python
利用Python进行异常值分析实例代码
Dec 07 Python
浅谈DataFrame和SparkSql取值误区
Jun 09 Python
Python判断以什么结尾以什么开头的实例
Oct 27 Python
Python制作exe文件简单流程
Jan 24 Python
Python使用Pickle模块进行数据保存和读取的讲解
Apr 09 Python
Python restful框架接口开发实现
Apr 13 Python
Python创建自己的加密货币的示例
Mar 01 Python
python 三边测量定位的实现代码
Apr 22 Python
python通配符之glob模块的使用详解
Apr 24 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
PHP笔记之:基于面向对象设计的详解
2013/05/14 PHP
win7 64位系统 配置php最新版开发环境(php+Apache+mysql)
2014/08/15 PHP
Yii数据库缓存实例分析
2016/03/29 PHP
基于jQuery实现的水平和垂直居中的div窗口
2011/08/08 Javascript
JavaScript高级程序设计 读书笔记之十一 内置对象Global
2012/03/07 Javascript
js控制frameSet示例
2013/09/10 Javascript
jquery 自定义容器下雨效果可将下雨图标改为其他
2014/04/23 Javascript
JS判断图片是否加载完成方法汇总(最新版)
2016/05/13 Javascript
bootstrap日历插件datetimepicker使用方法
2016/12/14 Javascript
JS简单实现移动端日历功能示例
2016/12/28 Javascript
微信小程序微信支付接入开发实例详解
2017/04/12 Javascript
angular学习之ngRoute路由机制
2017/04/12 Javascript
微信小程序分享功能之按钮button 边框隐藏和点击隐藏
2018/06/14 Javascript
详解JavaScript 中 if / if...else...替换方式
2018/07/15 Javascript
vue组件数据传递、父子组件数据获取,slot,router路由功能示例
2019/03/19 Javascript
使用Vue.set()方法实现响应式修改数组数据步骤
2019/11/09 Javascript
JavaScript常用工具函数大全
2020/05/06 Javascript
封装 axios+promise通用请求函数操作
2020/08/11 Javascript
vue 动态生成拓扑图的示例
2021/01/03 Vue.js
Python Nose框架编写测试用例方法
2017/10/26 Python
Django使用httpresponse返回用户头像实例代码
2018/01/26 Python
python 接口测试response返回数据对比的方法
2018/02/11 Python
pandas对指定列进行填充的方法
2018/04/11 Python
python实现控制台打印的方法
2019/01/12 Python
Python图像处理PIL各模块详细介绍(推荐)
2019/07/17 Python
详解Python IO编程
2020/07/24 Python
CSS3使用border-radius属性制作圆角
2014/12/22 HTML / CSS
Yves Rocher伊夫·黎雪美国官网:法国始创植物美肌1959
2019/01/09 全球购物
澳大利亚波西米亚风连衣裙在线商店:Fortunate One
2019/04/01 全球购物
英国领先的在线高尔夫商店:Scottsdale Golf
2019/08/26 全球购物
销售主管的自我评价分享
2014/01/03 职场文书
小学教师师德反思
2014/02/03 职场文书
党员转正大会主持词
2015/07/02 职场文书
MySQL实现配置主从复制项目实践
2022/03/31 MySQL
Python中requests库的用法详解
2022/06/05 Python
JS实现页面炫酷的时钟特效示例
2022/08/14 Javascript