基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pymongo给mongodb创建索引的简单实现方法
May 06 Python
教你用Python脚本快速为iOS10生成图标和截屏
Sep 22 Python
pandas数值计算与排序方法
Apr 12 Python
Python3.6实现连接mysql或mariadb的方法分析
May 18 Python
python读取excel指定列数据并写入到新的excel方法
Jul 10 Python
python pexpect ssh 远程登录服务器的方法
Feb 14 Python
Python递归求出列表(包括列表中的子列表)的最大值实例
Feb 27 Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 Python
python使用列表的最佳方案
Aug 12 Python
Django创建一个后台的基本步骤记录
Oct 02 Python
python tkinter Entry控件的焦点移动操作
May 22 Python
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
May 26 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
mysql_connect localhost和127.0.0.1的区别(网络层阐述)
2015/03/26 PHP
PHP对象的浅复制与深复制的实例详解
2017/10/26 PHP
改进版通过Json对象实现深复制的方法
2012/10/24 Javascript
jquery select 设置默认选中的示例代码
2014/02/07 Javascript
兼容所有浏览器的js复制插件Zero使用介绍
2014/03/19 Javascript
jQuery延迟加载图片插件Lazy Load使用指南
2015/03/25 Javascript
js实现获取当前时间是本月第几周的方法
2015/08/11 Javascript
JavaScript动态创建div等元素实例讲解
2016/01/06 Javascript
ArtEditor富文本编辑器增加表单提交功能
2016/04/18 Javascript
AngularJS入门教程引导程序
2016/08/18 Javascript
jQuery实现的自定义弹出层效果实例详解
2016/09/04 Javascript
探究Vue.js 2.0新增的虚拟DOM
2016/10/20 Javascript
js实现交通灯效果
2017/01/13 Javascript
JS数组操作之增删改查的简单实现
2017/08/21 Javascript
快速将Vue项目升级到webpack3的方法步骤
2017/09/14 Javascript
说说如何利用 Node.js 代理解决跨域问题
2019/04/22 Javascript
vuex入门最详细整理
2020/03/04 Javascript
JavaScript 链表定义与使用方法示例
2020/04/28 Javascript
vue 子组件修改data或调用操作
2020/08/07 Javascript
v-slot和slot、slot-scope之间相互替换实例
2020/09/04 Javascript
[02:08]什么藏在DOTA2 TI9“小紫本”里?斧王历险记告诉你!
2019/05/17 DOTA
在Python中调用ggplot的三种方法
2015/04/08 Python
python+Splinter实现12306抢票功能
2018/09/25 Python
详解Python是如何实现issubclass的
2019/07/24 Python
python单例模式原理与创建方法实例分析
2019/10/26 Python
使用Python爬虫库BeautifulSoup遍历文档树并对标签进行操作详解
2020/01/25 Python
完美解决Pycharm中matplotlib画图中文乱码问题
2021/01/11 Python
英国天然保健品网站:Simply Supplements
2017/03/22 全球购物
大学新学期计划书
2014/04/28 职场文书
民事诉讼代理委托书
2014/10/08 职场文书
运动会加油稿30字
2015/07/21 职场文书
2016见义勇为事迹材料汇总
2016/03/01 职场文书
html5移动端禁止长按图片保存的实现
2021/04/20 HTML / CSS
sql查询结果列拼接成逗号分隔的字符串方法
2021/05/25 SQL Server
Centos系统通过Docker安装并搭建MongoDB数据库
2022/04/12 MongoDB
Tomcat弱口令复现及利用
2022/05/06 Servers