TensorFlow搭建神经网络最佳实践


Posted in Python onMarch 09, 2018

一、TensorFLow完整样例

在MNIST数据集上,搭建一个简单神经网络结构,一个包含ReLU单元的非线性化处理的两层神经网络。在训练神经网络的时候,使用带指数衰减的学习率设置、使用正则化来避免过拟合、使用滑动平均模型来使得最终的模型更加健壮。

程序将计算神经网络前向传播的部分单独定义一个函数inference,训练部分定义一个train函数,再定义一个主函数main。

完整程序:

#!/usr/bin/env python3 
# -*- coding: utf-8 -*- 
""" 
Created on Thu May 25 08:56:30 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
INPUT_NODE = 784 # 输入节点数 
OUTPUT_NODE = 10 # 输出节点数 
LAYER1_NODE = 500 # 隐含层节点数 
BATCH_SIZE = 100 
LEARNING_RETE_BASE = 0.8 # 基学习率 
LEARNING_RETE_DECAY = 0.99 # 学习率的衰减率 
REGULARIZATION_RATE = 0.0001 # 正则化项的权重系数 
TRAINING_STEPS = 10000 # 迭代训练次数 
MOVING_AVERAGE_DECAY = 0.99 # 滑动平均的衰减系数 
 
# 传入神经网络的权重和偏置,计算神经网络前向传播的结果 
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2): 
  # 判断是否传入ExponentialMovingAverage类对象 
  if avg_class == None: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1) 
    return tf.matmul(layer1, weights2) + biases2 
  else: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) 
                   + avg_class.average(biases1)) 
    return tf.matmul(layer1, avg_class.average(weights2))\ 
             + avg_class.average(biases2) 
 
# 神经网络模型的训练过程 
def train(mnist): 
  x = tf.placeholder(tf.float32, [None,INPUT_NODE], name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input') 
 
  # 定义神经网络结构的参数 
  weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], 
                        stddev=0.1)) 
  biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE])) 
  weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], 
                        stddev=0.1)) 
  biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE])) 
 
  # 计算非滑动平均模型下的参数的前向传播的结果 
  y = inference(x, None, weights1, biases1, weights2, biases2) 
   
  global_step = tf.Variable(0, trainable=False) # 定义存储当前迭代训练轮数的变量 
 
  # 定义ExponentialMovingAverage类对象 
  variable_averages = tf.train.ExponentialMovingAverage( 
            MOVING_AVERAGE_DECAY, global_step) # 传入当前迭代轮数参数 
  # 定义对所有可训练变量trainable_variables进行更新滑动平均值的操作op 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
 
  # 计算滑动模型下的参数的前向传播的结果 
  average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2) 
 
  # 定义交叉熵损失值 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 
          logits=y, labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  # 定义L2正则化器并对weights1和weights2正则化 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  regularization = regularizer(weights1) + regularizer(weights2) 
  loss = cross_entropy_mean + regularization # 总损失值 
 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RETE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RETE_DECAY) 
  # 定义梯度下降操作op,global_step参数可实现自加1运算 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
             .minimize(loss, global_step=global_step) 
  # 组合两个操作op 
  train_op = tf.group(train_step, variables_averages_op) 
  ''''' 
  # 与tf.group()等价的语句 
  with tf.control_dependencies([train_step, variables_averages_op]): 
    train_op = tf.no_op(name='train') 
  ''' 
  # 定义准确率 
  # 在最终预测的时候,神经网络的输出采用的是经过滑动平均的前向传播计算结果 
  correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1)) 
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
  # 初始化回话sess并开始迭代训练 
  with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # 验证集待喂入数据 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试集待喂入数据 
    test_feed = {x: mnist.test.images, y_: mnist.test.labels} 
    for i in range(TRAINING_STEPS): 
      if i % 1000 == 0: 
        validate_acc = sess.run(accuracy, feed_dict=validate_feed) 
        print('After %d training steps, validation accuracy' 
           ' using average model is %f' % (i, validate_acc)) 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      sess.run(train_op, feed_dict={x: xs, y_:ys}) 
 
    test_acc = sess.run(accuracy, feed_dict=test_feed) 
    print('After %d training steps, test accuracy' 
       ' using average model is %f' % (TRAINING_STEPS, test_acc)) 
 
# 主函数 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
# 当前的python文件是shell文件执行的入口文件,而非当做import的python module。 
if __name__ == '__main__': # 在模块内部执行 
  tf.app.run() # 调用main函数并传入所需的参数list

二、分析与改进设计

1. 程序分析改进

第一,计算前向传播的函数inference中需要将所有的变量以参数的形式传入函数,当神经网络结构变得更加复杂、参数更多的时候,程序的可读性将变得非常差。

第二,在程序退出时,训练好的模型就无法再利用,且大型神经网络的训练时间都比较长,在训练过程中需要每隔一段时间保存一次模型训练的中间结果,这样如果在训练过程中程序死机,死机前的最新的模型参数仍能保留,杜绝了时间和资源的浪费。

第三,将训练和测试分成两个独立的程序,将训练和测试都会用到的前向传播的过程抽象成单独的库函数。这样就保证了在训练和预测两个过程中所调用的前向传播计算程序是一致的。

2. 改进后程序设计

mnist_inference.py

该文件中定义了神经网络的前向传播过程,其中的多次用到的weights定义过程又单独定义成函数。

通过tf.get_variable函数来获取变量,在神经网络训练时创建这些变量,在测试时会通过保存的模型加载这些变量的取值,而且可以在变量加载时将滑动平均值重命名。所以可以直接通过同样的名字在训练时使用变量自身,在测试时使用变量的滑动平均值。

mnist_train.py

该程序给出了神经网络的完整训练过程。

mnist_eval.py

在滑动平均模型上做测试。

通过tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)获取最新模型的文件名,实际是获取checkpoint文件的所有内容。

三、TensorFlow最佳实践样例

mnist_inference.py

import tensorflow as tf 
 
INPUT_NODE = 784 
OUTPUT_NODE = 10 
LAYER1_NODE = 500 
 
def get_weight_variable(shape, regularizer): 
  weights = tf.get_variable("weights", shape, 
         initializer=tf.truncated_normal_initializer(stddev=0.1)) 
  if regularizer != None: 
    # 将权重参数的正则化项加入至损失集合 
    tf.add_to_collection('losses', regularizer(weights)) 
  return weights 
 
def inference(input_tensor, regularizer): 
  with tf.variable_scope('layer1'): 
    weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) 
    biases = tf.get_variable("biases", [LAYER1_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) 
 
  with tf.variable_scope('layer2'): 
    weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer) 
    biases = tf.get_variable("biases", [OUTPUT_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer2 = tf.matmul(layer1, weights) + biases 
 
  return layer2

mnist_train.py

import os 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
 
BATCH_SIZE = 100 
LEARNING_RATE_BASE = 0.8 
LEARNING_RATE_DECAY = 0.99 
REGULARIZATION_RATE = 0.0001 
TRAINING_STEPS = 10000 
MOVING_AVERAGE_DECAY = 0.99 
 
MODEL_SAVE_PATH = "Model_Folder/" 
MODEL_NAME = "model.ckpt" 
 
def train(mnist): 
  # 定义输入placeholder 
  x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
            name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
            name='y-input') 
  # 定义正则化器及计算前向过程输出 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  y = mnist_inference.inference(x, regularizer) 
  # 定义当前训练轮数及滑动平均模型 
  global_step = tf.Variable(0, trainable=False) 
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, 
                             global_step) 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
  # 定义损失函数 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, 
                          labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY) 
  # 定义训练操作,包括模型训练及滑动模型操作 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
          .minimize(loss, global_step=global_step) 
  train_op = tf.group(train_step, variables_averages_op) 
  # 定义Saver类对象,保存模型,TensorFlow持久化类 
  saver = tf.train.Saver() 
 
  # 定义会话,启动训练过程 
  with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
 
    for i in range(TRAINING_STEPS): 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      _, loss_value, step = sess.run([train_op, loss, global_step], 
                      feed_dict={x: xs, y_: ys}) 
      if i % 1000 == 0: 
        print("After %d training step(s), loss on training batch is %g."\ 
            % (step, loss_value)) 
        # save方法的global_step参数可以让每个被保存的模型的文件名末尾加上当前训练轮数 
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), 
              global_step=global_step) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

mnist_eval.py

import time 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
import mnist_train 
 
EVAL_INTERVAL_SECS = 10 
 
def evaluate(mnist): 
  with tf.Graph().as_default() as g: 
    # 定义输入placeholder 
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
              name='x-input') 
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
              name='y-input') 
    # 定义feed字典 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试时不加参数正则化损失 
    y = mnist_inference.inference(x, None) 
    # 计算正确率 
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
    # 加载滑动平均模型下的参数值 
    variable_averages = tf.train.ExponentialMovingAverage( 
                   mnist_train.MOVING_AVERAGE_DECAY) 
    saver = tf.train.Saver(variable_averages.variables_to_restore()) 
 
    # 每隔EVAL_INTERVAL_SECS秒启动一次会话 
    while True: 
      with tf.Session() as sess: 
        ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH) 
        if ckpt and ckpt.model_checkpoint_path: 
          saver.restore(sess, ckpt.model_checkpoint_path) 
          # 取checkpoint文件中的当前迭代轮数global_step 
          global_step = ckpt.model_checkpoint_path\ 
                   .split('/')[-1].split('-')[-1] 
          accuracy_score = sess.run(accuracy, feed_dict=validate_feed) 
          print("After %s training step(s), validation accuracy = %g"\ 
             % (global_step, accuracy_score)) 
 
        else: 
          print('No checkpoint file found') 
          return 
      time.sleep(EVAL_INTERVAL_SECS) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  evaluate(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中线程的MQ消息队列实现以及消息队列的优点解析
Jun 29 Python
各种Python库安装包下载地址与安装过程详细介绍(Windows版)
Nov 02 Python
Python创建一个空的dataframe,并循环赋值的方法
Nov 08 Python
Python3爬虫教程之利用Python实现发送天气预报邮件
Dec 16 Python
Django的models模型的具体使用
Jul 15 Python
django 使用全局搜索功能的实例详解
Jul 18 Python
python 根据网易云歌曲的ID 直接下载歌曲的实例
Aug 24 Python
对Python中一维向量和一维向量转置相乘的方法详解
Aug 26 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 Python
python tqdm库的使用
Nov 30 Python
python使用openpyxl库读写Excel表格的方法(增删改查操作)
May 02 Python
OpenCV-Python实现轮廓的特征值
Jun 09 Python
TensorFlow实现Batch Normalization
Mar 08 #Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
You might like
getJSON跨域SyntaxError问题分析
2014/08/07 PHP
smarty中常用方法实例总结
2015/08/07 PHP
是 WordPress 让 PHP 更流行了 而不是框架
2016/02/03 PHP
php自定义函数实现JS的escape的方法示例
2016/07/07 PHP
21个值得收藏的Javascript技巧
2014/02/04 Javascript
jQuery的$.proxy()应用示例介绍
2014/04/03 Javascript
JQuery EasyUI 数字格式化处理示例
2014/05/05 Javascript
node.js中的buffer.Buffer.isEncoding方法使用说明
2014/12/14 Javascript
JQuery中属性过滤选择器用法实例分析
2015/05/18 Javascript
JS设置cookie、读取cookie
2016/02/24 Javascript
js事件处理程序跨浏览器解决方案
2016/03/27 Javascript
浅谈JavaScript的push(),pop(),concat()方法
2016/06/03 Javascript
JS实现将Asp.Net的DateTime Json类型转换为标准时间的方法
2016/08/02 Javascript
实例解析jQuery中如何取消后续执行内容
2016/12/01 Javascript
WEB开发之注册页面验证码倒计时代码的实现
2016/12/15 Javascript
jQuery插件FusionCharts绘制的2D条状图效果【附demo源码】
2017/05/13 jQuery
浅谈JavaScript中的属性:如何遍历属性
2017/09/14 Javascript
Vue中父子组件通讯之todolist组件功能开发
2018/05/21 Javascript
9种方法优化jQuery代码详解
2020/02/04 jQuery
[52:57]2014 DOTA2国际邀请赛中国区预选赛 LGD-CDEC VS HGT
2014/05/21 DOTA
Python and、or以及and-or语法总结
2015/04/14 Python
python删除过期文件的方法
2015/05/29 Python
浅析Python pandas模块输出每行中间省略号问题
2018/07/03 Python
浅谈python下tiff图像的读取和保存方法
2018/12/04 Python
Python 中pandas索引切片读取数据缺失数据处理问题
2019/10/09 Python
Python中pyecharts安装及安装失败的解决方法
2020/02/18 Python
python logging通过json文件配置的步骤
2020/04/27 Python
Python getattr()函数使用方法代码实例
2020/08/10 Python
Python抓包并解析json爬虫的完整实例代码
2020/11/03 Python
推荐WEB开发者最佳HTML5和CSS3代码生成器
2015/11/24 HTML / CSS
SIMON MILLER官网:洛杉矶的生活方式品牌
2020/10/19 全球购物
网络通讯中,端口有什么含义,端口的取值范围
2012/11/23 面试题
英语复习计划
2015/01/19 职场文书
MySQL查看表和清空表的常用命令总结
2021/05/26 MySQL
基于Java的MathML转图片的方法(示例代码)
2021/06/23 Java/Android
MySQL数据库如何查看表占用空间大小
2022/06/10 MySQL