基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python2.x中对Unicode编码的使用
Apr 03 Python
详解Django缓存处理中Vary头部的使用
Jul 24 Python
Python数据分析之双色球统计单个红和蓝球哪个比例高的方法
Feb 03 Python
python3+PyQt5实现柱状图
Apr 24 Python
使用Python做定时任务及时了解互联网动态
May 15 Python
Flask框架 CSRF 保护实现方法详解
Oct 30 Python
python实现将视频按帧读取到自定义目录
Dec 10 Python
如何使用python3获取当前路径及os.path.dirname的使用
Dec 13 Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 Python
python如何利用paramiko执行服务器命令
Nov 07 Python
Python爬取科目四考试题库的方法实现
Mar 30 Python
Django分页器的用法你都了解吗
May 26 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
linux下为php添加curl扩展的方法
2011/07/29 PHP
基于simple_html_dom的使用小结
2013/07/01 PHP
js利用Array.splice实现Array的insert/remove
2009/01/13 Javascript
利用WebBrowser彻底解决Web打印问题(包括后台打印)
2009/06/22 Javascript
js 多种变量定义(对象直接量,数组直接量和函数直接量)
2010/05/24 Javascript
jquery图片放大镜功能的实例代码
2013/03/26 Javascript
Javascript学习笔记之数组的遍历和 length 属性
2014/11/23 Javascript
jQuery实现的超简单点赞效果实例分析
2015/12/31 Javascript
基于Bootstrap使用jQuery实现简单可编辑表格
2016/05/04 Javascript
Google Maps基础及实例解析
2016/08/06 Javascript
js控制文本框只能输入中文、英文、数字与指定特殊符号的实现代码
2016/09/09 Javascript
BootStrap学习系列之布局组件(下拉,按钮组[toolbar],上拉)
2017/01/03 Javascript
详解JS中的快速排序与冒泡
2017/01/10 Javascript
在ABP框架中使用BootstrapTable组件的方法
2017/07/31 Javascript
vue-cli中打包图片路径错误的解决方法
2017/10/26 Javascript
Vue 组件传值几种常用方法【总结】
2018/05/28 Javascript
Vue实现todolist删除功能
2018/06/26 Javascript
解决vue axios的封装 请求状态的错误提示问题
2018/09/25 Javascript
layui对工具条进行选择性的显示方法
2019/09/19 Javascript
[01:32:50]DOTA2-DPC中国联赛 正赛 DLG vs XG BO3 第一场 1月25日
2021/03/11 DOTA
11个并不被常用但对开发非常有帮助的Python库
2015/03/31 Python
Python基于动态规划算法计算单词距离
2015/07/25 Python
Python3.6 Schedule模块定时任务(实例讲解)
2017/11/09 Python
Python多进程编程multiprocessing代码实例
2020/03/12 Python
VC++笔试题
2014/10/13 面试题
教师自荐书
2013/10/08 职场文书
2014年3.15团委活动总结
2014/03/16 职场文书
微笑面对生活演讲稿
2014/09/23 职场文书
英语通知范文
2015/04/22 职场文书
信息技术国培研修日志
2015/11/13 职场文书
电力企业职工培训心得体会
2016/01/11 职场文书
同学会演讲稿
2019/04/02 职场文书
党风廉政建设心得体会
2019/05/21 职场文书
python爬取新闻门户网站的示例
2021/04/25 Python
golang内置函数len的小技巧
2021/07/25 Golang
python之基数排序的实现
2021/07/26 Python