基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 获取et和excel的版本号
Apr 09 Python
用Python输出一个杨辉三角的例子
Jun 13 Python
python 简单的绘图工具turtle使用详解
Jun 21 Python
Python探索之SocketServer详解
Oct 28 Python
Python读取视频的两种方法(imageio和cv2)
Apr 15 Python
Python实现的多进程和多线程功能示例
May 29 Python
Python使用pymysql从MySQL数据库中读出数据的方法
Jul 25 Python
Python机器学习之scikit-learn库中KNN算法的封装与使用方法
Dec 14 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 Python
python3+opencv 使用灰度直方图来判断图片的亮暗操作
Jun 02 Python
python和php哪个更适合写爬虫
Jun 22 Python
python3爬虫中多线程的优势总结
Nov 24 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
Mysql的Root密码忘记,查看或修改的解决方法(图文介绍)
2013/06/14 PHP
php实现数组中出现次数超过一半的数字的统计方法
2018/10/14 PHP
javascript同步Import,同步调用外部js的方法
2008/07/08 Javascript
使用JS 清空File控件的路径值
2013/07/08 Javascript
JS 对输入框进行限制(常用的都有)
2013/07/30 Javascript
js 走马灯简单实例
2013/11/21 Javascript
JavaScript控制各种浏览器全屏模式的方法、属性和事件介绍
2014/04/03 Javascript
jquery ajax局部加载方法详解(实现代码)
2016/05/12 Javascript
node.js cookie-parser 中间件介绍
2016/06/06 Javascript
JavaScript如何实现跨域请求
2016/08/05 Javascript
js实现可输入可选择的select下拉框
2016/12/21 Javascript
详解js产生对象的3种基本方式(工厂模式,构造函数模式,原型模式)
2017/01/09 Javascript
利用JS实现scroll自定义滚动效果详解
2017/10/17 Javascript
vuejs选中当前样式active的实例
2018/08/22 Javascript
微信小程序ibeacon三点定位详解
2018/10/31 Javascript
微信小程序实现搜索历史功能
2020/03/26 Javascript
浅谈vue中组件绑定事件时是否加.native
2019/11/09 Javascript
python实现数值积分的Simpson方法实例分析
2015/06/05 Python
Django Highcharts制作图表
2016/08/27 Python
让Python更加充分的使用Sqlite3
2017/12/11 Python
python ftp 按目录结构上传下载的实现代码
2018/09/12 Python
使用python实现快速搭建简易的FTP服务器
2018/09/12 Python
Django如何防止定时任务并发浅析
2019/05/14 Python
Python实现微信机器人的方法
2019/09/06 Python
python数据类型可变不可变知识点总结
2020/03/06 Python
python sleep和wait对比总结
2021/02/03 Python
海淘零差价,宝贝全球购: 宝贝格子
2016/08/24 全球购物
美国演唱会和体育门票购买网站:Ticketnetwork
2018/10/19 全球购物
应用电子技术专业个人求职信
2013/09/21 职场文书
委托书怎么写
2014/07/31 职场文书
超市创业计划书
2014/09/15 职场文书
2014最新实习证明模板
2014/10/02 职场文书
如何用PHP实现多线程编程
2021/05/26 PHP
Python 阶乘详解
2021/10/05 Python
Win11任务栏无法正常显示 资源管理器不停重启的解决方法
2022/07/07 数码科技
SQLServer常见数学函数梳理总结
2022/08/05 MySQL