TensorFlow训练MNIST手写数字识别模型


Posted in Python on February 13, 2020

本文实例为大家分享了用TensorFlow训练MNIST手写数字识别模型的具体代码，供大家参考。注意：代码基于TensorFlow 1.x API（如tf.placeholder、tf.Session、tf.contrib），无法直接在TensorFlow 2.x中运行。具体内容如下：

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
# ---- Network topology -------------------------------------------------------
INPUT_NODE = 28 * 28  # input-layer nodes = pixels of a 28x28 image (784)
OUTPUT_NODE = 10      # output-layer nodes = number of image classes

LAYER1_NODE = 500     # nodes in the single hidden layer
# Examples per training batch. Smaller batches behave more like stochastic
# gradient descent; larger ones more like full-batch gradient descent.
BATCH_SIZE = 100

# ---- Optimisation schedule --------------------------------------------------
LEARNING_RATE_BASE = 0.8      # base (initial) learning rate
LEARNING_RATE_DECAY = 0.99    # learning-rate decay rate

REGULARIZATION_RATE = 0.0001  # coefficient of the regularization term
TRAINING_STEPS = 30000        # number of training iterations
MOVING_AVG_DECAY = 0.99       # moving-average decay rate
 
# 定义一个辅助函数,给定神经网络的输入和所有参数,计算神经网络的前向传播结果
def inference(input_tensor, avg_class, weights1, biases1,
              weights2, biases2):
    """Build the forward pass of the one-hidden-layer network.

    Args:
        input_tensor: Input batch, multiplied on the right by ``weights1``
            (presumably shape ``[batch, INPUT_NODE]`` -- confirm with caller).
        avg_class: ``None`` to use the parameters' current values. Otherwise
            an object exposing ``average(var)`` (such as
            ``tf.train.ExponentialMovingAverage``) whose averaged values are
            used in place of each parameter.
        weights1, biases1: Hidden-layer weights and biases.
        weights2, biases2: Output-layer weights and biases.

    Returns:
        The output layer ``layer1 @ W2 + b2``; no softmax is applied here.
    """
    # Pick how each parameter is looked up, once, instead of duplicating the
    # whole two-layer formula per branch. `is None` is the idiomatic singleton
    # test and cannot be affected by an overloaded __eq__.
    if avg_class is None:
        def resolve(var):
            return var
    else:
        resolve = avg_class.average

    # Hidden layer: ReLU(x @ W1 + b1).
    layer1 = tf.nn.relu(
        tf.matmul(input_tensor, resolve(weights1)) + resolve(biases1))
    # Output layer (linear).
    return tf.matmul(layer1, resolve(weights2)) + resolve(biases2)
 
# 训练模型的过程
def train(mnist):
    """Build the MNIST MLP graph, train it, and print accuracies.

    Validation accuracy of the moving-average model is printed every 1000
    steps. Test accuracy is printed once, after ``TRAINING_STEPS`` steps.

    Args:
        mnist: Dataset object. This function reads ``train.num_examples``,
            ``train.next_batch``, and ``images``/``labels`` of both
            ``validation`` and ``test``.
    """
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # Hidden-layer parameters.
    weights1 = tf.Variable(
        tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))

    # Output-layer parameters.
    weights2 = tf.Variable(
        tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    # Forward pass on the raw (non-averaged) parameters; this is what trains.
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Iteration counter. trainable=False keeps it out of
    # tf.trainable_variables(), so it is neither optimized nor averaged;
    # minimize(..., global_step=...) increments it.
    global_step = tf.Variable(0, trainable=False)

    # Cross-entropy loss; argmax over axis 1 turns the label rows in y_ into
    # class indices for the sparse op.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # L2 penalty on the weight matrices only (biases are not regularized).
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)

    loss = cross_entropy_mean + regularization

    # Exponentially decayed learning rate: shrinks by a factor of
    # LEARNING_RATE_DECAY per pass over the training data.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,                            # current iteration
        mnist.train.num_examples / BATCH_SIZE,  # iterations per full pass
        LEARNING_RATE_DECAY)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Moving averages of all trainable parameters, with global_step passed
    # as num_updates.
    variable_avgs = tf.train.ExponentialMovingAverage(
        MOVING_AVG_DECAY, global_step)

    # Make the average update run strictly AFTER the gradient step. Listing
    # both ops in a single control_dependencies() only orders them before the
    # no_op, not relative to each other, so the averages could read the
    # weights (and global_step) before or after the update
    # nondeterministically.
    with tf.control_dependencies([train_step]):
        variables_avgs_op = variable_avgs.apply(tf.trainable_variables())
    with tf.control_dependencies([variables_avgs_op]):
        train_op = tf.no_op(name='train')

    # Forward pass on the averaged (shadow) parameters. Built after apply(),
    # since average() only knows variables that apply() has registered.
    avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)

    # Accuracy of the moving-average model.
    correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        # Validation data: monitored during training.
        validate_feed = {x: mnist.validation.images,
                         y_: mnist.validation.labels}

        # Test data: evaluated once, as the final measure of the model.
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print("After %d training step(s), validation accuracy using average "
                      "model is %g " % (i, validate_acc))

            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # Final accuracy on the held-out test set.
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model "
              "is %g " % (TRAINING_STEPS, test_acc))
  
  
# 主程序入口
def main(argv=None):
    """Read MNIST from "/tmp/data" with one-hot labels and train on it.

    Args:
        argv: Arguments handed over by ``tf.app.run``; not used here.
    """
    train(input_data.read_data_sets("/tmp/data", one_hot=True))
 
# Tensorflow主程序入口
if __name__ == '__main__':
    # tf.app.run handles command-line flags and then calls main(argv).
    # main is passed explicitly rather than looked up on the __main__ module.
    tf.app.run(main=main)

输出结果如下:

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462 
After 1000 training step(s), validation accuracy using average model is 0.9784 
After 2000 training step(s), validation accuracy using average model is 0.9806 
After 3000 training step(s), validation accuracy using average model is 0.9798 
After 4000 training step(s), validation accuracy using average model is 0.9814 
After 5000 training step(s), validation accuracy using average model is 0.9826 
After 6000 training step(s), validation accuracy using average model is 0.9828 
After 7000 training step(s), validation accuracy using average model is 0.9832 
After 8000 training step(s), validation accuracy using average model is 0.9838 
After 9000 training step(s), validation accuracy using average model is 0.983 
After 10000 training step(s), validation accuracy using average model is 0.9836 
After 11000 training step(s), validation accuracy using average model is 0.9822 
After 12000 training step(s), validation accuracy using average model is 0.983 
After 13000 training step(s), validation accuracy using average model is 0.983 
After 14000 training step(s), validation accuracy using average model is 0.9844 
After 15000 training step(s), validation accuracy using average model is 0.9832 
After 16000 training step(s), validation accuracy using average model is 0.9844 
After 17000 training step(s), validation accuracy using average model is 0.9842 
After 18000 training step(s), validation accuracy using average model is 0.9842 
After 19000 training step(s), validation accuracy using average model is 0.9838 
After 20000 training step(s), validation accuracy using average model is 0.9834 
After 21000 training step(s), validation accuracy using average model is 0.9828 
After 22000 training step(s), validation accuracy using average model is 0.9834 
After 23000 training step(s), validation accuracy using average model is 0.9844 
After 24000 training step(s), validation accuracy using average model is 0.9838 
After 25000 training step(s), validation accuracy using average model is 0.9834 
After 26000 training step(s), validation accuracy using average model is 0.984 
After 27000 training step(s), validation accuracy using average model is 0.984 
After 28000 training step(s), validation accuracy using average model is 0.9836 
After 29000 training step(s), validation accuracy using average model is 0.9842 
After 30000 training steps, test accuracy using average model is 0.9839

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python切换hosts文件代码示例
Dec 31 Python
Python实现读取目录所有文件的文件名并保存到txt文件代码
Nov 22 Python
python比较两个列表是否相等的方法
Jul 28 Python
win系统下为Python3.5安装flask-mongoengine 库
Dec 20 Python
python 把文件中的每一行以数组的元素放入数组中的方法
Apr 29 Python
python3连接MySQL数据库实例详解
May 24 Python
Python socket模块方法实现详解
Nov 05 Python
tensorflow之获取tensor的shape作为max_pool的ksize实例
Jan 04 Python
python数据预处理 :数据共线性处理详解
Feb 24 Python
python GUI库图形界面开发之PyQt5滚动条控件QScrollBar详细使用方法与实例
Mar 06 Python
Python字节单位转换(将字节转换为K M G T)
Mar 02 Python
教你怎么用python爬取爱奇艺热门电影
May 20 Python
Python3 读取Word文件方式
Feb 13 #Python
解决Python import docx出错DLL load failed的问题
Feb 13 #Python
python求最大公约数和最小公倍数的简单方法
Feb 13 #Python
python圣诞树编写实例详解
Feb 13 #Python
python如何实现复制目录到指定目录
Feb 13 #Python
Python制作简易版小工具之计算天数的实现思路
Feb 13 #Python
解决python-docx打包之后找不到default.docx的问题
Feb 13 #Python
You might like
如何开发一个虚拟域名系统
2006/10/09 PHP
浅谈php函数serialize()与unserialize()的使用方法
2014/08/19 PHP
Smarty模板常见的简单应用分析
2016/11/15 PHP
php使用yield对性能提升的测试实例分析
2019/09/19 PHP
JQuery拖动表头边框线调整表格列宽效果代码
2014/09/10 Javascript
JavaScript数据类型检测代码分享
2015/01/26 Javascript
JavaScript类继承及实例化的方法
2015/07/25 Javascript
JavaScript头像上传插件源码分享
2016/03/29 Javascript
JS中IP地址与整数相互转换的实现代码
2017/04/10 Javascript
js轮播图透明度切换(带上下页和底部圆点切换)
2017/04/27 Javascript
解决webpack无法通过IP地址访问localhost的问题
2018/02/22 Javascript
vue.js项目nginx部署教程
2018/04/05 Javascript
详解javascript中的Error对象
2019/04/25 Javascript
配置一个vue3.0项目的完整步骤
2019/04/26 Javascript
微信小程序一周时间表功能实现
2019/10/17 Javascript
Vue将props值实时传递 并可修改的操作
2020/08/09 Javascript
[01:02:45]完美世界DOTA2联赛 LBZS vs Forest 第三场 11.07
2020/11/09 DOTA
python实现网页链接提取的方法分享
2014/02/25 Python
Python编程实现两个文件夹里文件的对比功能示例【包含内容的对比】
2017/06/20 Python
python+selenium实现京东自动登录及秒杀功能
2017/11/18 Python
Python3实现统计单词表中每个字母出现频率的方法示例
2019/01/28 Python
python并发编程 Process对象的其他属性方法join方法详解
2019/08/20 Python
Python装饰器使用你可能不知道的几种姿势
2019/10/25 Python
Django 多对多字段的更新和插入数据实例
2020/03/31 Python
详解python命令提示符窗口下如何运行python脚本
2020/09/11 Python
html5指南-2.如何操作document metadata
2013/01/07 HTML / CSS
html5 worker 实例(一) 为什么测试不到效果
2013/06/24 HTML / CSS
Bugatchi官方网站:男士服装在线
2019/04/10 全球购物
环保建议书作文
2014/03/12 职场文书
银行职员工作失误检讨书
2014/10/14 职场文书
学习心理学的体会
2014/11/07 职场文书
公司清洁工岗位职责
2015/04/15 职场文书
爱国主义影片观后感
2015/06/18 职场文书
关爱留守儿童主题班会
2015/08/13 职场文书
初中数学课堂教学反思
2016/02/17 职场文书
德劲DE1108畅想
2021/04/22 无线电