基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读取浮点数和读取文本文件示例
May 06 Python
Python检测QQ在线状态的方法
May 09 Python
python处理按钮消息的实例详解
Jul 11 Python
详解用Python处理HTML转义字符的5种方式
Dec 27 Python
Python3.6安装及引入Requests库的实现方法
Jan 24 Python
python入门教程 python入门神图一张
Mar 05 Python
Python使用gRPC传输协议教程
Oct 16 Python
Python基础教程之异常详解
Jan 10 Python
Python logging模块异步线程写日志实现过程解析
Jun 30 Python
基于python实现删除指定文件类型
Jul 21 Python
python 决策树算法的实现
Oct 09 Python
如何用tempfile库创建python进程中的临时文件
Jan 28 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
PHP中实现中文字符进制转换原理分析
2011/12/06 PHP
Laravel中基于Artisan View扩展包创建及删除应用视图文件的方法
2016/10/08 PHP
document对象execCommand的command参数介绍
2006/08/01 Javascript
JavaScript 学习笔记(六)
2009/12/31 Javascript
jQuery-ui中自动完成实现方法
2010/06/10 Javascript
javascript dom代码应用 简单的相册[firefox only]
2010/06/12 Javascript
Javascript之旅 对象的原型链之由来
2010/08/25 Javascript
js调用后台servlet方法实例
2013/06/09 Javascript
JS 实现导航栏悬停效果(续)
2013/09/24 Javascript
express的中间件basicAuth详解
2014/12/04 Javascript
jquery中的工具使用方法$.isFunction, $.isArray(), $.isWindow()
2015/08/09 Javascript
基于Bootstrap实现图片轮播效果
2016/05/22 Javascript
底部悬浮通栏可以关闭广告位的实现方法
2016/06/01 Javascript
JS实现物体带缓冲的间歇运动效果示例
2016/12/22 Javascript
JavaScript获取tr td 的三种方式全面总结(推荐)
2017/08/15 Javascript
Vue中Table组件Select的勾选和取消勾选事件详解
2019/03/19 Javascript
如何使用50行javaScript代码实现简单版的call,apply,bind
2019/08/14 Javascript
微信小程序实现一个简单swiper代码实例
2019/12/30 Javascript
vue 限制input只能输入正数的操作
2020/08/05 Javascript
Openlayers绘制地图标注
2020/09/28 Javascript
node脚手架搭建服务器实现token验证的方法
2021/01/20 Javascript
Python的Flask框架应用调用Redis队列数据的方法
2016/06/06 Python
Python实现爬虫从网络上下载文档的实例代码
2018/06/13 Python
django 发送邮件和缓存的实现代码
2018/07/18 Python
python爬虫获取小区经纬度以及结构化地址
2018/12/30 Python
django ManyToManyField多对多关系的实例详解
2019/08/09 Python
python模拟预测一下新型冠状病毒肺炎的数据
2020/02/01 Python
Django视图、传参和forms验证操作
2020/07/15 Python
python实现文件+参数发送request的实例代码
2021/01/05 Python
Skyscanner香港:机票比价, 平机票和廉价航空机票预订
2020/02/07 全球购物
The North Face意大利官网:服装、背包和鞋子
2020/06/17 全球购物
extern是什么意思
2016/03/10 面试题
编辑硕士自荐信范文
2013/11/27 职场文书
高三励志标语
2014/06/05 职场文书
重阳节座谈会主持词
2015/07/03 职场文书
Java中常用解析工具jackson及fastjson的使用
2021/06/28 Java/Android