Tensorflow训练MNIST手写数字识别模型


Posted in Python on February 13, 2020

本文实例为大家分享了Tensorflow训练MNIST手写数字识别模型的具体代码,供大家参考,具体内容如下

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
# --- Network shape -----------------------------------------------------
# Input layer width: one node per pixel of a 28x28 image.
INPUT_NODE = 28 * 28
# Output layer width: one node per image class.
OUTPUT_NODE = 10
# Width of the single hidden layer.
LAYER1_NODE = 500

# --- Optimisation ------------------------------------------------------
# Number of examples per training batch. Smaller values behave more like
# stochastic gradient descent, larger values more like full gradient descent.
BATCH_SIZE = 100
# Base learning rate, before exponential decay is applied.
LEARNING_RATE_BASE = 0.8
# Decay factor of the learning rate.
LEARNING_RATE_DECAY = 0.99
# Coefficient of the regularization term in the loss.
REGULARIZATION_RATE = 1e-4
# Number of training steps.
TRAINING_STEPS = 30000
# Decay rate of the moving average kept for the parameters.
MOVING_AVG_DECAY = 0.99
 
def inference(input_tensor, avg_class, weights1, biases1,
              weights2, biases2):
    """Compute the forward-propagation result of the one-hidden-layer network.

    Args:
        input_tensor: Network input; it is multiplied with ``weights1`` via
            ``tf.matmul``.
        avg_class: ``None`` to use the current values of the parameters, or
            an object exposing ``average(variable)`` (``train`` passes a
            ``tf.train.ExponentialMovingAverage``) to use the averaged value
            of each parameter instead.
        weights1: Hidden-layer weights.
        biases1: Hidden-layer biases.
        weights2: Output-layer weights.
        biases2: Output-layer biases.

    Returns:
        The output-layer tensor ``relu(input @ w1 + b1) @ w2 + b2``. No
        softmax is applied here.
    """
    # Identity test rather than ``== None``: the latter dispatches to the
    # object's ``__eq__`` and is not guaranteed to yield a plain bool.
    if avg_class is None:
        # No moving-average class supplied: use the parameters as they are.
        w1, b1, w2, b2 = weights1, biases1, weights2, biases2
    else:
        # Look up the moving average of every parameter first.
        w1 = avg_class.average(weights1)
        b1 = avg_class.average(biases1)
        w2 = avg_class.average(weights2)
        b2 = avg_class.average(biases2)

    # Hidden layer with ReLU activation, followed by the linear output layer.
    layer1 = tf.nn.relu(tf.matmul(input_tensor, w1) + b1)
    return tf.matmul(layer1, w2) + b2
 
def train(mnist):
    """Build the network graph, train it, and print accuracy figures.

    The loss is computed from the raw parameters, while the reported
    accuracy is computed from the moving-average parameters. Validation
    accuracy is printed every 1000 steps and test accuracy once after the
    last step.

    Args:
        mnist: Dataset object providing ``train.num_examples``,
            ``train.next_batch(batch_size)``, and ``images`` / ``labels``
            on its ``validation`` and ``test`` members.

    Returns:
        None. Results are written to standard output.
    """
    # Placeholders for a batch of inputs and the matching label vectors.
    images = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    labels = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # Hidden-layer parameters.
    hidden_w = tf.Variable(
        tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    hidden_b = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))

    # Output-layer parameters.
    output_w = tf.Variable(
        tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    output_b = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    params = (hidden_w, hidden_b, output_w, output_b)

    # Forward pass on the raw parameters (no moving average).
    logits = inference(images, None, *params)

    # Step counter; marked non-trainable so it is excluded from
    # tf.trainable_variables() below.
    global_step = tf.Variable(0, trainable=False)

    # Moving-average tracker driven by the decay rate and the step counter,
    # applied to every trainable variable.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVG_DECAY, global_step)
    update_averages = ema.apply(tf.trainable_variables())

    # Forward pass on the moving-average parameters.
    averaged_logits = inference(images, ema, *params)

    # Cross-entropy loss; argmax turns each label vector into a class index.
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.argmax(labels, 1))
    data_loss = tf.reduce_mean(per_example_loss)

    # L2 regularization on the weights only (biases are not regularized).
    l2 = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    weight_penalty = l2(hidden_w) + l2(output_w)

    loss = data_loss + weight_penalty

    # Exponentially decaying learning rate; the decay period is the number
    # of batches needed to go through all training examples once.
    batches_per_epoch = mnist.train.num_examples / BATCH_SIZE
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        batches_per_epoch,
        LEARNING_RATE_DECAY)

    # Gradient-descent step; minimize() also increments global_step.
    gradient_step = tf.train.GradientDescentOptimizer(
        learning_rate).minimize(loss, global_step=global_step)

    # One op that triggers both the parameter update and the refresh of
    # the moving averages.
    with tf.control_dependencies([gradient_step, update_averages]):
        train_op = tf.no_op(name='train')

    # Accuracy of the moving-average model: fraction of rows whose
    # predicted class matches the labelled class.
    hits = tf.equal(tf.argmax(averaged_logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Validation data: used to monitor progress during training.
        validation_feed = {images: mnist.validation.images,
                           labels: mnist.validation.labels}

        # Test data: used for the final evaluation only.
        test_feed = {images: mnist.test.images, labels: mnist.test.labels}

        for step in range(TRAINING_STEPS):
            if step % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validation_feed)
                print("After %d training step(s), validation accuracy using average "
                      "model is %g " % (step, validate_acc))

            batch_images, batch_labels = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op,
                     feed_dict={images: batch_images, labels: batch_labels})

        # Final accuracy on the test set once training has finished.
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model "
              "is %g " % (TRAINING_STEPS, test_acc))
  
  
def main(argv=None):
    """Program entry point: load the MNIST data and run training.

    The data set is read from ``/tmp/data`` with ``one_hot=True`` and
    handed straight to ``train``.

    Args:
        argv: Not used by this function. Presumably present because
            ``tf.app.run()`` calls ``main`` with the argument list.
    """
    train(input_data.read_data_sets("/tmp/data", one_hot=True))
 
# TensorFlow program entry point: when executed as a script, hand control
# to tf.app.run(), which by TensorFlow convention invokes main() above.
if __name__ == '__main__':
 tf.app.run()

输出结果如下:

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462 
After 1000 training step(s), validation accuracy using average model is 0.9784 
After 2000 training step(s), validation accuracy using average model is 0.9806 
After 3000 training step(s), validation accuracy using average model is 0.9798 
After 4000 training step(s), validation accuracy using average model is 0.9814 
After 5000 training step(s), validation accuracy using average model is 0.9826 
After 6000 training step(s), validation accuracy using average model is 0.9828 
After 7000 training step(s), validation accuracy using average model is 0.9832 
After 8000 training step(s), validation accuracy using average model is 0.9838 
After 9000 training step(s), validation accuracy using average model is 0.983 
After 10000 training step(s), validation accuracy using average model is 0.9836 
After 11000 training step(s), validation accuracy using average model is 0.9822 
After 12000 training step(s), validation accuracy using average model is 0.983 
After 13000 training step(s), validation accuracy using average model is 0.983 
After 14000 training step(s), validation accuracy using average model is 0.9844 
After 15000 training step(s), validation accuracy using average model is 0.9832 
After 16000 training step(s), validation accuracy using average model is 0.9844 
After 17000 training step(s), validation accuracy using average model is 0.9842 
After 18000 training step(s), validation accuracy using average model is 0.9842 
After 19000 training step(s), validation accuracy using average model is 0.9838 
After 20000 training step(s), validation accuracy using average model is 0.9834 
After 21000 training step(s), validation accuracy using average model is 0.9828 
After 22000 training step(s), validation accuracy using average model is 0.9834 
After 23000 training step(s), validation accuracy using average model is 0.9844 
After 24000 training step(s), validation accuracy using average model is 0.9838 
After 25000 training step(s), validation accuracy using average model is 0.9834 
After 26000 training step(s), validation accuracy using average model is 0.984 
After 27000 training step(s), validation accuracy using average model is 0.984 
After 28000 training step(s), validation accuracy using average model is 0.9836 
After 29000 training step(s), validation accuracy using average model is 0.9842 
After 30000 training steps, test accuracy using average model is 0.9839

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python根据区号生成手机号码的方法
Jul 08 Python
Python与Java间Socket通信实例代码
Mar 06 Python
Python实现简单过滤文本段的方法
May 24 Python
pandas数据框,统计某列数据对应的个数方法
Apr 11 Python
Python爬虫之正则表达式的使用教程详解
Oct 25 Python
python引入不同文件夹下的自定义模块方法
Oct 27 Python
Python零基础入门学习之输入与输出
Apr 03 Python
PyCharm无法识别PyQt5的2种解决方法,ModuleNotFoundError: No module named 'pyqt5'
Feb 17 Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 Python
python 统计代码耗时的几种方法分享
Apr 02 Python
python文件名批量重命名脚本实例代码
Apr 22 Python
一些让Python代码简洁的实用技巧总结
Aug 23 Python
Python3 读取Word文件方式
Feb 13 #Python
解决Python import docx出错DLL load failed的问题
Feb 13 #Python
python求最大公约数和最小公倍数的简单方法
Feb 13 #Python
python圣诞树编写实例详解
Feb 13 #Python
python如何实现复制目录到指定目录
Feb 13 #Python
Python制作简易版小工具之计算天数的实现思路
Feb 13 #Python
解决python-docx打包之后找不到default.docx的问题
Feb 13 #Python
You might like
php数组函数序列之asort() - 对数组的元素值进行升序排序,保持索引关系
2011/11/02 PHP
php目录操作实例代码
2014/02/21 PHP
手把手编写PHP框架 深入了解MVC运行流程
2016/09/19 PHP
jQuery方法简洁实现隔行换色及toggleClass的使用
2013/03/15 Javascript
jQuery ui插件的使用方法代码实例
2013/05/08 Javascript
JavaScript位移运算符(无符号) >>> 三个大于号 的使用方法详解
2016/03/31 Javascript
AngularJS中的指令全面解析(必看)
2016/05/20 Javascript
Vuejs第十一篇组件之slot内容分发实例详解
2016/09/09 Javascript
yarn与npm的命令行小结
2016/10/20 Javascript
jQuery插件echarts实现的单折线图效果示例【附demo源码下载】
2017/03/04 Javascript
JavaScript评论点赞功能的实现方法
2017/03/13 Javascript
微信小程序 支付功能(前端)的实现
2017/05/24 Javascript
jQuery实现的文字逐行向上间歇滚动效果示例
2017/09/06 jQuery
Angularjs渲染的 using 指令的星级评分系统示例
2017/11/09 Javascript
vue.js实现点击后动态添加class及删除同级class的实现代码
2018/04/04 Javascript
vue-cli2.9.3 详细教程
2018/04/23 Javascript
基于Angularjs-router动态改变Title值的问题
2018/08/30 Javascript
vue数据操作之点击事件实现num加减功能示例
2019/01/19 Javascript
koa2 从入门到精通(小结)
2019/07/23 Javascript
基于vue3.0.1beta搭建仿京东的电商H5项目
2020/05/06 Javascript
Python中作用域的深入讲解
2018/12/10 Python
pandas实现to_sql将DataFrame保存到数据库中
2019/07/03 Python
Python tkinter实现图片标注功能(完整代码)
2019/12/08 Python
基于打开pycharm有带图片md文件卡死问题的解决
2020/04/24 Python
Python如何实现机器人聊天
2020/09/10 Python
python连接mysql数据库并读取数据的实现
2020/09/25 Python
全球性的在线时尚男装零售商:boohooMAN
2016/12/17 全球购物
美国新兴城市生活方式零售商:VILLA
2017/12/06 全球购物
德国香水、化妆品和护理产品网上商店:Parfumdreams
2018/09/26 全球购物
建筑工程技术应届生求职信
2013/11/17 职场文书
高中数学教学反思
2014/01/30 职场文书
班级活动策划书
2014/02/06 职场文书
2014年综合治理工作总结
2014/11/20 职场文书
2014年销售工作总结与计划
2014/12/01 职场文书
公司放假通知范文
2015/04/14 职场文书
采购员工作总结范文
2015/08/12 职场文书