基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的数据结构与算法之双端队列详解
Apr 22 Python
python连接字符串的方法小结
Jul 13 Python
python实现按行切分文本文件的方法
Apr 18 Python
Python实现八大排序算法
Aug 13 Python
python 获取字符串MD5值方法
May 29 Python
python使用 zip 同时迭代多个序列示例
Jul 06 Python
python3模拟实现xshell远程执行liunx命令的方法
Jul 12 Python
Python Threading 线程/互斥锁/死锁/GIL锁
Jul 21 Python
django2.2安装错误最全的解决方案(小结)
Sep 24 Python
python学生管理系统的实现
Apr 05 Python
如何将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
Apr 22 Python
python中sys模块是做什么用的
Aug 16 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
PHP Global定义全局变量使用说明
2013/08/15 PHP
简单理解PHP的面向对象编程方式
2016/05/17 PHP
CI框架整合smarty步骤详解
2016/05/19 PHP
php 可变函数使用小结
2018/06/12 PHP
js验证表单大全
2006/11/25 Javascript
nodejs实现黑名单中间件设计
2014/06/17 NodeJs
JavaScript创建一个object对象并操作对象属性的用法
2015/03/23 Javascript
JS实现简单的图书馆享元模式实例
2015/06/30 Javascript
阿里巴巴技术文章分享 Javascript继承机制的实现
2016/01/14 Javascript
快速掌握Node.js事件驱动模型
2016/03/21 Javascript
Vue2实现组件props双向绑定
2016/12/02 Javascript
js实现3d悬浮效果
2017/02/16 Javascript
JS拉起或下载app的实现代码
2017/02/22 Javascript
深入理解nodejs中Express的中间件
2017/05/19 NodeJs
CSS3结合jQuery实现动画效果及回调函数的实例
2017/12/27 jQuery
electron-vue利用webpack打包实现多页面的入口文件问题
2019/05/12 Javascript
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
2016/01/20 Python
python中pandas.DataFrame对行与列求和及添加新行与列示例
2017/03/12 Python
python入门:这篇文章带你直接学会python
2018/09/14 Python
浅谈Pycharm最有必要改的几个默认设置项
2020/02/14 Python
HTML5实现音频和视频嵌入的方法
2018/08/22 HTML / CSS
西班牙电子产品购物网站:Electronicamente
2018/07/26 全球购物
学年末自我鉴定
2014/01/21 职场文书
新书吧创业计划书
2014/01/31 职场文书
三八妇女节超市活动方案
2014/08/18 职场文书
机械制造专业大学生自我鉴定
2014/09/19 职场文书
毕业证委托书范文
2014/09/26 职场文书
2014年业务员工作总结范文
2014/11/17 职场文书
2014离婚协议书范文(3篇)
2014/11/29 职场文书
党员自我评价2015
2015/03/03 职场文书
解除劳动合同通知书范本
2015/04/16 职场文书
2015年小学重阳节活动总结
2015/07/29 职场文书
2015年秋季运动会广播稿
2015/08/19 职场文书
2015年幼儿园班主任个人工作总结
2015/10/22 职场文书
解决Vue+SpringBoot+Shiro跨域问题
2021/06/09 Vue.js
Android Rxjava3 使用场景详解
2022/04/07 Java/Android