基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python实现baidu hi自动登录的代码
Feb 10 Python
Python实现读取目录所有文件的文件名并保存到txt文件代码
Nov 22 Python
python连接数据库的方法
Oct 19 Python
解决已经安装requests,却依然提示No module named requests问题
May 18 Python
解决Mac下使用python的坑
Aug 13 Python
8段用于数据清洗Python代码(小结)
Oct 31 Python
python进程池实现的多进程文件夹copy器完整示例
Nov 27 Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 Python
编译 pycaffe时报错:fatal error: numpy/arrayobject.h没有那个文件或目录
Nov 29 Python
VSCode中autopep8无法运行问题解决方案(提示Error: Command failed,usage)
Mar 02 Python
python基于tkinter制作下班倒计时工具
Apr 28 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
php抽象类用法实例分析
2015/07/07 PHP
PHP实现清除wordpress里恶意代码
2015/10/21 PHP
thinkPHP自定义类实现方法详解
2016/11/30 PHP
yii2.0整合阿里云oss删除单个文件的方法
2017/09/19 PHP
Laravel项目中timeAgo字段语言转换的改善方法示例
2019/09/16 PHP
ThinkPHP 5.1 跨域配置方法
2019/10/11 PHP
对联广告js flash激活
2006/10/19 Javascript
js 数据类型转换总结笔记
2011/01/17 Javascript
读jQuery之一(对象的组成)
2011/06/11 Javascript
jQuery插件实现屏蔽单个元素使用户无法点击
2013/04/12 Javascript
AngularJS HTML编译器介绍
2014/12/06 Javascript
浅谈String.valueOf()方法的使用
2016/06/06 Javascript
javascript时间差插件分享
2016/07/18 Javascript
Javascript中级语法快速入手
2016/07/30 Javascript
jQuery的ajax和遍历数组json实例代码
2016/08/01 Javascript
简单三步实现报表页面集成天气
2016/12/15 Javascript
JS常见疑难点分析之match,charAt,charCodeAt,map,search用法分析
2016/12/25 Javascript
jQueryeasyui 中如何使用datetimebox 取两个日期间相隔的天数
2017/06/13 jQuery
Vuejs 2.0 子组件访问/调用父组件的方法(示例代码)
2018/02/08 Javascript
基于Vue2.X的路由和钩子函数详解
2018/02/09 Javascript
vue form check 表单验证的实现代码
2018/12/09 Javascript
python3制作捧腹网段子页爬虫
2017/02/12 Python
python利用requests库模拟post请求时json的使用教程
2018/12/07 Python
python argparse模块通过后台传递参数实例
2020/04/20 Python
快速了解Python开发环境Spyder
2020/06/29 Python
python中的时区问题
2021/01/14 Python
美国中小型企业领先的办公家具供应商:Office Designs
2016/11/26 全球购物
法国女性内衣购物网站:Glamuse
2019/05/13 全球购物
捷科时代的软件测试笔试题
2015/11/09 面试题
syb养殖创业计划书
2014/01/09 职场文书
总账会计岗位职责
2014/03/13 职场文书
科长竞争上岗演讲稿
2014/05/12 职场文书
运动会三级跳加油稿
2015/07/21 职场文书
晶体管来复再生式二管收音机
2021/04/22 无线电
Win11筛选键导致键盘失灵怎么解决? Win11关闭筛选键的技巧
2022/04/08 数码科技
Tomcat项目启动失败的原因和解决办法
2022/04/20 Servers