Tensorflow训练MNIST手写数字识别模型


Posted in Python onFebruary 13, 2020

本文实例为大家分享了Tensorflow训练MNIST手写数字识别模型的具体代码,供大家参考,具体内容如下

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
INPUT_NODE = 784  # 输入层节点=图片像素=28x28=784
OUTPUT_NODE = 10  # 输出层节点数=图片类别数目
 
LAYER1_NODE = 500  # 隐藏层节点数,只有一个隐藏层
BATCH_SIZE = 100  # 一个训练包中的数据个数,数字越小
          # 越接近随机梯度下降,越大越接近梯度下降
 
LEARNING_RATE_BASE = 0.8   # 基础学习率
LEARNING_RATE_DECAY = 0.99  # 学习率衰减率
 
REGULARIZATION_RATE = 0.0001  # 正则化项系数
TRAINING_STEPS = 30000     # 训练轮数
MOVING_AVG_DECAY = 0.99    # 滑动平均衰减率
 
# 定义一个辅助函数,给定神经网络的输入和所有参数,计算神经网络的前向传播结果
def inference(input_tensor, avg_class, weights1, biases1,
       weights2, biases2):
 
 # 当没有提供滑动平均类时,直接使用参数当前取值
 if avg_class == None:
  # 计算隐藏层前向传播结果
  layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
  # 计算输出层前向传播结果
  return tf.matmul(layer1, weights2) + biases2
 else:
  # 首先计算变量的滑动平均值,然后计算前向传播结果
  layer1 = tf.nn.relu(
    tf.matmul(input_tensor, avg_class.average(weights1)) +
    avg_class.average(biases1))
  
  return tf.matmul(
    layer1, avg_class.average(weights2)) + avg_class.average(biases2)
 
# 训练模型的过程
def train(mnist):
 x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
 y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
 
 # 生成隐藏层参数
 weights1 = tf.Variable(
   tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
 biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
 
 # 生成输出层参数
 weights2 = tf.Variable(
   tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
 biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
 
 # 计算前向传播结果,不使用参数滑动平均值 avg_class=None
 y = inference(x, None, weights1, biases1, weights2, biases2)
 
 # 定义训练轮数变量,指定为不可训练
 global_step = tf.Variable(0, trainable=False)
 
 # 给定滑动平均衰减率和训练轮数的变量,初始化滑动平均类
 variable_avgs = tf.train.ExponentialMovingAverage(
   MOVING_AVG_DECAY, global_step)
 
 # 在所有代表神经网络参数的可训练变量上使用滑动平均
 variables_avgs_op = variable_avgs.apply(tf.trainable_variables())
 
 # 计算使用滑动平均值后的前向传播结果
 avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)
 
 # 计算交叉熵作为损失函数
 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
   logits=y, labels=tf.argmax(y_, 1))
 cross_entropy_mean = tf.reduce_mean(cross_entropy)
 
 # 计算L2正则化损失函数
 regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
 regularization = regularizer(weights1) + regularizer(weights2)
 
 loss = cross_entropy_mean + regularization
 
 # 设置指数衰减的学习率
 learning_rate = tf.train.exponential_decay(
   LEARNING_RATE_BASE,
   global_step,              # 当前迭代轮数
   mnist.train.num_examples / BATCH_SIZE, # 过完所有训练数据的迭代次数
   LEARNING_RATE_DECAY)
 
 
 # 优化损失函数
 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
   loss, global_step=global_step)
 
 # 反向传播同时更新神经网络参数及其滑动平均值
 with tf.control_dependencies([train_step, variables_avgs_op]):
  train_op = tf.no_op(name='train')
 
 # 检验使用了滑动平均模型的神经网络前向传播结果是否正确
 correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
 
 # 初始化会话并开始训练
 with tf.Session() as sess:
  tf.global_variables_initializer().run()
  
  # 准备验证数据,用于判断停止条件和训练效果
  validate_feed = {x: mnist.validation.images,
          y_: mnist.validation.labels}
  
  # 准备测试数据,用于模型优劣的最后评价标准
  test_feed = {x: mnist.test.images, y_: mnist.test.labels}
  
  # 迭代训练神经网络
  for i in range(TRAINING_STEPS):
   if i%1000 == 0:
    validate_acc = sess.run(accuracy, feed_dict=validate_feed)
    print("After %d training step(s), validation accuracy using average " 
       "model is %g " % (i, validate_acc))
    
   xs, ys = mnist.train.next_batch(BATCH_SIZE)
   sess.run(train_op, feed_dict={x: xs, y_: ys})
  
  # 训练结束后在测试集上检测模型的最终正确率
  test_acc = sess.run(accuracy, feed_dict=test_feed)
  print("After %d training steps, test accuracy using average model "
     "is %g " % (TRAINING_STEPS, test_acc))
  
  
# 主程序入口
def main(argv=None):
 mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
 train(mnist)
 
# Tensorflow主程序入口
if __name__ == '__main__':
 tf.app.run()

输出结果如下:

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462 
After 1000 training step(s), validation accuracy using average model is 0.9784 
After 2000 training step(s), validation accuracy using average model is 0.9806 
After 3000 training step(s), validation accuracy using average model is 0.9798 
After 4000 training step(s), validation accuracy using average model is 0.9814 
After 5000 training step(s), validation accuracy using average model is 0.9826 
After 6000 training step(s), validation accuracy using average model is 0.9828 
After 7000 training step(s), validation accuracy using average model is 0.9832 
After 8000 training step(s), validation accuracy using average model is 0.9838 
After 9000 training step(s), validation accuracy using average model is 0.983 
After 10000 training step(s), validation accuracy using average model is 0.9836 
After 11000 training step(s), validation accuracy using average model is 0.9822 
After 12000 training step(s), validation accuracy using average model is 0.983 
After 13000 training step(s), validation accuracy using average model is 0.983 
After 14000 training step(s), validation accuracy using average model is 0.9844 
After 15000 training step(s), validation accuracy using average model is 0.9832 
After 16000 training step(s), validation accuracy using average model is 0.9844 
After 17000 training step(s), validation accuracy using average model is 0.9842 
After 18000 training step(s), validation accuracy using average model is 0.9842 
After 19000 training step(s), validation accuracy using average model is 0.9838 
After 20000 training step(s), validation accuracy using average model is 0.9834 
After 21000 training step(s), validation accuracy using average model is 0.9828 
After 22000 training step(s), validation accuracy using average model is 0.9834 
After 23000 training step(s), validation accuracy using average model is 0.9844 
After 24000 training step(s), validation accuracy using average model is 0.9838 
After 25000 training step(s), validation accuracy using average model is 0.9834 
After 26000 training step(s), validation accuracy using average model is 0.984 
After 27000 training step(s), validation accuracy using average model is 0.984 
After 28000 training step(s), validation accuracy using average model is 0.9836 
After 29000 training step(s), validation accuracy using average model is 0.9842 
After 30000 training steps, test accuracy using average model is 0.9839

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python EOL while scanning string literal问题解决方法
Sep 18 Python
python中list列表的高级函数
May 17 Python
python实现音乐下载器
Apr 15 Python
Django-Rest-Framework 权限管理源码浅析(小结)
Nov 12 Python
Python去除字符串前后空格的几种方法
Mar 04 Python
Django urls.py重构及参数传递详解
Jul 23 Python
Python使用random模块生成随机数操作实例详解
Sep 17 Python
Python 实现训练集、测试集随机划分
Jan 08 Python
django 模型字段设置默认值代码
Jul 15 Python
matplotlib运行时配置(Runtime Configuration,rc)参数rcParams解析
Jan 05 Python
python爬虫scrapy基本使用超详细教程
Feb 20 Python
python3 删除所有自定义变量的操作
Apr 08 Python
Python3 读取Word文件方式
Feb 13 #Python
解决Python import docx出错DLL load failed的问题
Feb 13 #Python
python求最大公约数和最小公倍数的简单方法
Feb 13 #Python
python圣诞树编写实例详解
Feb 13 #Python
python如何实现复制目录到指定目录
Feb 13 #Python
Python制作简易版小工具之计算天数的实现思路
Feb 13 #Python
解决python-docx打包之后找不到default.docx的问题
Feb 13 #Python
You might like
mysql时区问题
2008/03/26 PHP
分享下PHP register_globals 值为on与off的理解
2013/09/26 PHP
php使用imagick模块实现图片缩放、裁剪、压缩示例
2014/04/17 PHP
php通过前序遍历树实现无需递归的无限极分类
2015/07/10 PHP
YII2框架中使用yii.js实现的post请求
2017/04/09 PHP
php简单构造json多维数组的方法示例
2017/06/08 PHP
jQuery 页面 Mask实现代码
2010/01/09 Javascript
基于jQuery的淡入淡出可自动切换的幻灯插件
2010/08/24 Javascript
修复ie8&chrome下window的resize事件多次执行
2011/10/20 Javascript
jquery分页插件AmSetPager(自写)
2013/04/15 Javascript
事件委托与阻止冒泡阻止其父元素事件触发
2014/09/02 Javascript
使用JSON.parse将json字符串转换成json对象的时候会出错
2014/09/04 Javascript
JavaScript中this详解
2015/09/01 Javascript
javascript绘制漂亮的心型线效果完整实例
2016/02/02 Javascript
JavaScript中push(),join() 函数 实例详解
2016/09/06 Javascript
Vue开发中整合axios的文件整理
2017/04/29 Javascript
详解开源的JavaScript插件化框架MinimaJS
2017/10/26 Javascript
JS基于递归实现网页版计算器的方法分析
2017/12/20 Javascript
微信小程序基于canvas渐变实现的彩虹效果示例
2019/05/03 Javascript
vue组件实现移动端九宫格转盘抽奖
2020/10/16 Javascript
[07:47]DOTA2国际邀请赛采访专栏:探访Valve总部
2013/08/08 DOTA
Python for Informatics 第11章之正则表达式(二)
2016/04/21 Python
一篇文章快速了解Python的GIL
2018/01/12 Python
python实现数据写入excel表格
2018/03/25 Python
pybind11和numpy进行交互的方法
2019/07/04 Python
python使用sklearn实现决策树的方法示例
2019/09/12 Python
Pycharm修改python路径过程图解
2020/05/22 Python
Python定时任务框架APScheduler原理及常用代码
2020/10/05 Python
万得城电器土耳其网站:欧洲第一大电子产品零售商
2016/10/07 全球购物
荷兰音乐会和音乐剧门票订购网站:Topticketshop
2019/08/27 全球购物
意大利和国际最佳时尚品牌:Drestige
2019/12/28 全球购物
外贸英语专业求职信范文
2013/12/25 职场文书
工作骂脏话检讨书
2014/10/05 职场文书
12.4全国法制宣传日活动方案
2014/11/02 职场文书
教师党员自我评价2015
2015/03/04 职场文书
如何把新闻人物写得立体、鲜活?
2019/08/14 职场文书