基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用面向对象方式创建线程实现12306售票系统
Dec 24 Python
django开发之settings.py中变量的全局引用详解
Mar 29 Python
pygame游戏之旅 按钮上添加文字的方法
Nov 21 Python
Python使用itchat模块实现简单的微信控制电脑功能示例
Aug 26 Python
python nmap实现端口扫描器教程
May 28 Python
python基于三阶贝塞尔曲线的数据平滑算法
Dec 27 Python
Tkinter中复选菜单是否被选中的判断与设置方式
Mar 04 Python
基于python实现获取网页图片过程解析
May 11 Python
Python爬虫抓取指定网页图片代码实例
Jul 24 Python
Django自定义YamlField实现过程解析
Nov 11 Python
快速创建python 虚拟环境
Nov 28 Python
python中lower函数实现方法及用法讲解
Dec 23 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
Symfony2开发之控制器用法实例分析
2016/02/05 PHP
详解Yii2高级版引入bootstrap.js的一个办法
2017/03/21 PHP
php微信公众号开发之微信企业付款给个人
2018/10/04 PHP
解决Laravel 使用insert插入数据,字段created_at为0000的问题
2019/10/11 PHP
利用Jquery实现可多选的下拉框
2014/02/21 Javascript
JavaScript使用Replace进行字符串替换的方法
2015/04/14 Javascript
jQuery+HTML5实现手机摇一摇换衣特效
2015/06/05 Javascript
不依赖Flash和任何JS库实现文本复制与剪切附源码下载
2015/10/09 Javascript
JS自定义函数对web前端上传的文件进行类型大小判断
2016/10/19 Javascript
js实时获取窗口大小变化的实例代码
2016/11/18 Javascript
Angular2学习笔记——详解路由器模型(Router)
2016/12/02 Javascript
jQuery实现select模糊查询(反射机制)
2017/01/14 Javascript
JavaScript中for循环的几种写法与效率总结
2017/02/03 Javascript
js获取地址栏中传递的参数(两种方法)
2017/02/08 Javascript
浅谈AngularJs 双向绑定原理(数据绑定机制)
2017/12/07 Javascript
浅谈Vue.use的使用
2018/08/29 Javascript
vue 实现websocket发送消息并实时接收消息
2019/12/09 Javascript
vue实现计算器功能
2020/02/22 Javascript
详解elementUI中input框无法输入的问题
2020/04/27 Javascript
基于JS实现快速读取TXT文件
2020/08/25 Javascript
VUE前端从后台请求过来的数据进行转换数据结构操作
2020/11/11 Javascript
利用Python中的mock库对Python代码进行模拟测试
2015/04/16 Python
在Django框架中编写Contact表单的教程
2015/07/17 Python
Python函数中*args和**kwargs来传递变长参数的用法
2016/01/26 Python
pycharm安装和首次使用教程
2018/08/27 Python
浅谈Python里面None True False之间的区别
2020/07/09 Python
Notino罗马尼亚网站:购买香水和化妆品
2019/07/20 全球购物
工程部主管岗位职责
2013/11/17 职场文书
艺校音乐专业自我鉴定范文
2014/03/01 职场文书
给学校的建议书范文
2014/05/15 职场文书
学校通报表扬范文
2015/05/04 职场文书
2015年班组建设工作总结
2015/05/13 职场文书
2019年英语版感谢信(8篇)
2019/09/29 职场文书
Vue实现下拉加载更多
2021/05/09 Vue.js
Python中requests库的用法详解
2022/06/05 Python
win server2012 r2服务器共享文件夹如何设置
2022/06/21 Servers