基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基础之函数用法实例详解
Sep 10 Python
Python基于分水岭算法解决走迷宫游戏示例
Sep 26 Python
Python输出带颜色的字符串实例
Oct 10 Python
Python定时器实例代码
Nov 01 Python
Python中分支语句与循环语句实例详解
Sep 13 Python
Python global全局变量函数详解
Sep 18 Python
selenium + python 获取table数据的示例讲解
Oct 13 Python
django的settings中设置中文支持的实现
Apr 28 Python
python实现多线程端口扫描
Aug 31 Python
基于Python爬取爱奇艺资源过程解析
Mar 02 Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 Python
selenium设置浏览器为headless无头模式(Chrome和Firefox)
Jan 08 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
Drupal7连接多个数据库及常见问题解决
2014/03/02 PHP
ThinkPHP的截取字符串函数无法显示省略号的解决方法
2014/06/25 PHP
event.srcElement 用法笔记e.target
2009/12/18 Javascript
基于jquery的返回顶部效果(兼容IE6)
2011/01/17 Javascript
jQuery写的日历(包括日历的样式及功能)
2013/04/23 Javascript
图标线性回归斜着移动到指定的位置
2013/08/16 Javascript
js中this用法实例详解
2015/05/05 Javascript
jQuery+css实现的换页标签栏效果
2016/01/27 Javascript
javascript常见数字进制转换实例分析
2016/04/21 Javascript
JS生成不重复的随机数组的简单实例
2016/07/10 Javascript
浅谈javascript alert和confirm的美化
2016/12/15 Javascript
JavaScript实现获取远程的html到当前页面中
2017/03/26 Javascript
jquery仿京东商品放大浏览页面
2017/06/06 jQuery
解决vue中对象属性改变视图不更新的问题
2018/02/23 Javascript
JS获取浏览器地址栏的多个参数值的任意值实例代码
2018/07/24 Javascript
js实现下拉框二级联动
2018/12/04 Javascript
layui 解决富文本框form表单提交为空的问题
2019/10/26 Javascript
vue打开新窗口并实现传参的图文实例
2021/03/04 Vue.js
利用python程序帮大家清理windows垃圾
2017/01/15 Python
基于python的多进程共享变量正确打开方式
2018/04/28 Python
Python寻找两个有序数组的中位数实例详解
2018/12/05 Python
浅析Python 中几种字符串格式化方法及其比较
2019/07/02 Python
wxpython自定义下拉列表框过程图解
2020/02/14 Python
python使用paramiko实现ssh的功能详解
2020/03/06 Python
Python装饰器的应用场景代码总结
2020/04/10 Python
Django中Q查询及Q()对象 F查询及F()对象用法
2020/07/09 Python
基于python实现图片转字符画代码实例
2020/09/04 Python
浅析Python中字符串的intern机制
2020/10/03 Python
python 基于opencv实现图像增强
2020/12/23 Python
基于HTML5 audio元素播放声音jQuery小插件
2011/05/11 HTML / CSS
html5 Canvas画图教程(1)—画图的基本常识
2013/01/09 HTML / CSS
Urban Decay官方网站:美国化妆品品牌
2020/06/04 全球购物
幼儿园小班评语大全
2014/04/17 职场文书
驾驶员培训方案
2014/05/01 职场文书
校园文化标语
2014/06/18 职场文书
Python实战之实现简易的学生选课系统
2021/05/25 Python