基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用scrapy解析js示例
Jan 23 Python
Python中使用glob和rmtree删除目录子目录及所有文件的例子
Nov 21 Python
Python中的两个内置模块介绍
Apr 05 Python
python获取目录下所有文件的方法
Jun 01 Python
Python星号*与**用法分析
Feb 02 Python
Python多继承顺序实例分析
May 26 Python
python3 kmp 字符串匹配的方法
Jul 07 Python
解决Python一行输出不显示的问题
Dec 03 Python
Django对接支付宝实现支付宝充值金币功能示例
Dec 17 Python
Python解释器及PyCharm工具安装过程
Feb 26 Python
python异常处理、自定义异常、断言原理与用法分析
Mar 23 Python
解决jupyter notebook图片显示模糊和保存清晰图片的操作
Apr 24 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
通过ODBC连接的SQL SERVER实例
2006/10/09 PHP
PHP 小心urldecode引发的SQL注入漏洞
2011/10/27 PHP
PHP连接MongoDB示例代码
2012/09/06 PHP
getJSON跨域SyntaxError问题分析
2014/08/07 PHP
PHP开发框架Laravel数据库操作方法总结
2014/09/03 PHP
php目录拷贝实现方法
2015/07/10 PHP
使用phpstorm和xdebug实现远程调试的方法
2015/12/29 PHP
javascript getElementsByClassName实现代码
2010/10/11 Javascript
JavaScript高级程序设计 扩展--关于动态原型
2010/11/09 Javascript
jQuery验证Checkbox是否选中的代码 推荐
2011/09/04 Javascript
javascript函数以及基础写法100多条实用整理
2013/01/13 Javascript
jquery实现带二级菜单的导航示例
2014/04/28 Javascript
chrome下img加载对height()的影响示例探讨
2014/05/26 Javascript
JavaScript中对象介绍
2014/12/31 Javascript
JSONP跨域请求
2017/03/02 Javascript
详解angular 中的自定义指令之详解API
2017/06/20 Javascript
Vue内容分发slot(全面解析)
2017/08/19 Javascript
JS脚本实现网页自动秒杀点击
2018/01/11 Javascript
javascript性能优化之分时函数的介绍
2018/03/28 Javascript
vee-validate vue 2.0自定义表单验证的实例
2018/08/28 Javascript
通过JavaScript下载文件到本地的方法(单文件)
2019/03/17 Javascript
浅析Vue 中的 render 函数
2020/02/28 Javascript
uniapp开发小程序实现滑动页面控制元素的显示和隐藏效果
2020/12/10 Javascript
python使用新浪微博api上传图片到微博示例
2014/01/10 Python
在Python编程过程中用单元测试法调试代码的介绍
2015/04/02 Python
在Python中使用M2Crypto模块实现AES加密的教程
2015/04/08 Python
Python的time模块中的常用方法整理
2015/06/18 Python
python+matplotlib实现礼盒柱状图实例代码
2018/01/16 Python
python实现对文件中图片生成带标签的txt文件方法
2018/04/27 Python
python 字典的打印实现
2019/09/26 Python
Python GUI编程学习笔记之tkinter界面布局显示详解
2020/03/30 Python
python-jwt用户认证食用教学的实现方法
2021/01/19 Python
Canvas在超级玛丽游戏中的应用详解
2021/02/06 HTML / CSS
医学生自我鉴定范文
2014/03/26 职场文书
家庭贫困证明
2014/09/23 职场文书
小学优秀学生评语
2014/12/29 职场文书