TensorFlow搭建神经网络最佳实践


Posted in Python onMarch 09, 2018

一、TensorFLow完整样例

在MNIST数据集上,搭建一个简单神经网络结构,一个包含ReLU单元的非线性化处理的两层神经网络。在训练神经网络的时候,使用带指数衰减的学习率设置、使用正则化来避免过拟合、使用滑动平均模型来使得最终的模型更加健壮。

程序将计算神经网络前向传播的部分单独定义一个函数inference,训练部分定义一个train函数,再定义一个主函数main。

完整程序:

#!/usr/bin/env python3 
# -*- coding: utf-8 -*- 
""" 
Created on Thu May 25 08:56:30 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
INPUT_NODE = 784 # 输入节点数 
OUTPUT_NODE = 10 # 输出节点数 
LAYER1_NODE = 500 # 隐含层节点数 
BATCH_SIZE = 100 
LEARNING_RETE_BASE = 0.8 # 基学习率 
LEARNING_RETE_DECAY = 0.99 # 学习率的衰减率 
REGULARIZATION_RATE = 0.0001 # 正则化项的权重系数 
TRAINING_STEPS = 10000 # 迭代训练次数 
MOVING_AVERAGE_DECAY = 0.99 # 滑动平均的衰减系数 
 
# 传入神经网络的权重和偏置,计算神经网络前向传播的结果 
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2): 
  # 判断是否传入ExponentialMovingAverage类对象 
  if avg_class == None: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1) 
    return tf.matmul(layer1, weights2) + biases2 
  else: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) 
                   + avg_class.average(biases1)) 
    return tf.matmul(layer1, avg_class.average(weights2))\ 
             + avg_class.average(biases2) 
 
# 神经网络模型的训练过程 
def train(mnist): 
  x = tf.placeholder(tf.float32, [None,INPUT_NODE], name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input') 
 
  # 定义神经网络结构的参数 
  weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], 
                        stddev=0.1)) 
  biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE])) 
  weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], 
                        stddev=0.1)) 
  biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE])) 
 
  # 计算非滑动平均模型下的参数的前向传播的结果 
  y = inference(x, None, weights1, biases1, weights2, biases2) 
   
  global_step = tf.Variable(0, trainable=False) # 定义存储当前迭代训练轮数的变量 
 
  # 定义ExponentialMovingAverage类对象 
  variable_averages = tf.train.ExponentialMovingAverage( 
            MOVING_AVERAGE_DECAY, global_step) # 传入当前迭代轮数参数 
  # 定义对所有可训练变量trainable_variables进行更新滑动平均值的操作op 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
 
  # 计算滑动模型下的参数的前向传播的结果 
  average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2) 
 
  # 定义交叉熵损失值 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 
          logits=y, labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  # 定义L2正则化器并对weights1和weights2正则化 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  regularization = regularizer(weights1) + regularizer(weights2) 
  loss = cross_entropy_mean + regularization # 总损失值 
 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RETE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RETE_DECAY) 
  # 定义梯度下降操作op,global_step参数可实现自加1运算 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
             .minimize(loss, global_step=global_step) 
  # 组合两个操作op 
  train_op = tf.group(train_step, variables_averages_op) 
  ''''' 
  # 与tf.group()等价的语句 
  with tf.control_dependencies([train_step, variables_averages_op]): 
    train_op = tf.no_op(name='train') 
  ''' 
  # 定义准确率 
  # 在最终预测的时候,神经网络的输出采用的是经过滑动平均的前向传播计算结果 
  correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1)) 
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
  # 初始化回话sess并开始迭代训练 
  with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # 验证集待喂入数据 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试集待喂入数据 
    test_feed = {x: mnist.test.images, y_: mnist.test.labels} 
    for i in range(TRAINING_STEPS): 
      if i % 1000 == 0: 
        validate_acc = sess.run(accuracy, feed_dict=validate_feed) 
        print('After %d training steps, validation accuracy' 
           ' using average model is %f' % (i, validate_acc)) 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      sess.run(train_op, feed_dict={x: xs, y_:ys}) 
 
    test_acc = sess.run(accuracy, feed_dict=test_feed) 
    print('After %d training steps, test accuracy' 
       ' using average model is %f' % (TRAINING_STEPS, test_acc)) 
 
# 主函数 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
# 当前的python文件是shell文件执行的入口文件,而非当做import的python module。 
if __name__ == '__main__': # 在模块内部执行 
  tf.app.run() # 调用main函数并传入所需的参数list

二、分析与改进设计

1. 程序分析改进

第一,计算前向传播的函数inference中需要将所有的变量以参数的形式传入函数,当神经网络结构变得更加复杂、参数更多的时候,程序的可读性将变得非常差。

第二,在程序退出时,训练好的模型就无法再利用,且大型神经网络的训练时间都比较长,在训练过程中需要每隔一段时间保存一次模型训练的中间结果,这样如果在训练过程中程序死机,死机前的最新的模型参数仍能保留,杜绝了时间和资源的浪费。

第三,将训练和测试分成两个独立的程序,将训练和测试都会用到的前向传播的过程抽象成单独的库函数。这样就保证了在训练和预测两个过程中所调用的前向传播计算程序是一致的。

2. 改进后程序设计

mnist_inference.py

该文件中定义了神经网络的前向传播过程,其中的多次用到的weights定义过程又单独定义成函数。

通过tf.get_variable函数来获取变量,在神经网络训练时创建这些变量,在测试时会通过保存的模型加载这些变量的取值,而且可以在变量加载时将滑动平均值重命名。所以可以直接通过同样的名字在训练时使用变量自身,在测试时使用变量的滑动平均值。

mnist_train.py

该程序给出了神经网络的完整训练过程。

mnist_eval.py

在滑动平均模型上做测试。

通过tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)获取最新模型的文件名,实际是获取checkpoint文件的所有内容。

三、TensorFlow最佳实践样例

mnist_inference.py

import tensorflow as tf 
 
INPUT_NODE = 784 
OUTPUT_NODE = 10 
LAYER1_NODE = 500 
 
def get_weight_variable(shape, regularizer): 
  weights = tf.get_variable("weights", shape, 
         initializer=tf.truncated_normal_initializer(stddev=0.1)) 
  if regularizer != None: 
    # 将权重参数的正则化项加入至损失集合 
    tf.add_to_collection('losses', regularizer(weights)) 
  return weights 
 
def inference(input_tensor, regularizer): 
  with tf.variable_scope('layer1'): 
    weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) 
    biases = tf.get_variable("biases", [LAYER1_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) 
 
  with tf.variable_scope('layer2'): 
    weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer) 
    biases = tf.get_variable("biases", [OUTPUT_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer2 = tf.matmul(layer1, weights) + biases 
 
  return layer2

mnist_train.py

import os 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
 
BATCH_SIZE = 100 
LEARNING_RATE_BASE = 0.8 
LEARNING_RATE_DECAY = 0.99 
REGULARIZATION_RATE = 0.0001 
TRAINING_STEPS = 10000 
MOVING_AVERAGE_DECAY = 0.99 
 
MODEL_SAVE_PATH = "Model_Folder/" 
MODEL_NAME = "model.ckpt" 
 
def train(mnist): 
  # 定义输入placeholder 
  x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
            name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
            name='y-input') 
  # 定义正则化器及计算前向过程输出 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  y = mnist_inference.inference(x, regularizer) 
  # 定义当前训练轮数及滑动平均模型 
  global_step = tf.Variable(0, trainable=False) 
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, 
                             global_step) 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
  # 定义损失函数 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, 
                          labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY) 
  # 定义训练操作,包括模型训练及滑动模型操作 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
          .minimize(loss, global_step=global_step) 
  train_op = tf.group(train_step, variables_averages_op) 
  # 定义Saver类对象,保存模型,TensorFlow持久化类 
  saver = tf.train.Saver() 
 
  # 定义会话,启动训练过程 
  with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
 
    for i in range(TRAINING_STEPS): 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      _, loss_value, step = sess.run([train_op, loss, global_step], 
                      feed_dict={x: xs, y_: ys}) 
      if i % 1000 == 0: 
        print("After %d training step(s), loss on training batch is %g."\ 
            % (step, loss_value)) 
        # save方法的global_step参数可以让每个被保存的模型的文件名末尾加上当前训练轮数 
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), 
              global_step=global_step) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

mnist_eval.py

import time 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
import mnist_train 
 
EVAL_INTERVAL_SECS = 10 
 
def evaluate(mnist): 
  with tf.Graph().as_default() as g: 
    # 定义输入placeholder 
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
              name='x-input') 
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
              name='y-input') 
    # 定义feed字典 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试时不加参数正则化损失 
    y = mnist_inference.inference(x, None) 
    # 计算正确率 
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
    # 加载滑动平均模型下的参数值 
    variable_averages = tf.train.ExponentialMovingAverage( 
                   mnist_train.MOVING_AVERAGE_DECAY) 
    saver = tf.train.Saver(variable_averages.variables_to_restore()) 
 
    # 每隔EVAL_INTERVAL_SECS秒启动一次会话 
    while True: 
      with tf.Session() as sess: 
        ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH) 
        if ckpt and ckpt.model_checkpoint_path: 
          saver.restore(sess, ckpt.model_checkpoint_path) 
          # 取checkpoint文件中的当前迭代轮数global_step 
          global_step = ckpt.model_checkpoint_path\ 
                   .split('/')[-1].split('-')[-1] 
          accuracy_score = sess.run(accuracy, feed_dict=validate_feed) 
          print("After %s training step(s), validation accuracy = %g"\ 
             % (global_step, accuracy_score)) 
 
        else: 
          print('No checkpoint file found') 
          return 
      time.sleep(EVAL_INTERVAL_SECS) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  evaluate(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python代码检查工具pylint 让你的python更规范
Sep 05 Python
python实现ping的方法
Jul 06 Python
Python爬虫DOTA排行榜爬取实例(分享)
Jun 13 Python
Python获取CPU、内存使用率以及网络使用状态代码
Feb 08 Python
Tensorflow 同时载入多个模型的实例讲解
Jul 27 Python
Django框架模板介绍
Jan 15 Python
使用Python批量修改文件名的代码实例
Jan 24 Python
Python实现自定义读写分离代码实例
Nov 16 Python
通过实例简单了解Python中yield的作用
Dec 11 Python
keras 特征图可视化实例(中间层)
Jan 24 Python
python 使用cx-freeze打包程序的实现
Mar 14 Python
python+selenium+chrome批量文件下载并自动创建文件夹实例
Apr 27 Python
TensorFlow实现Batch Normalization
Mar 08 #Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
You might like
PHP编程与应用
2006/10/09 PHP
如何使用PHP往windows中添加用户
2006/12/06 PHP
应用开发中涉及到的css和php笔记分享
2011/08/02 PHP
PHP使用观察者模式处理异常信息的方法详解
2019/09/24 PHP
jQuery侧边栏随窗口滚动实现方法
2013/03/04 Javascript
鼠标滚轮控制网页横向移动实现思路
2013/03/22 Javascript
js保留两位小数使用toFixed实现
2013/07/29 Javascript
js数组方法扩展实现数组统计函数
2014/04/09 Javascript
drag-and-drop实现图片浏览器预览
2015/08/06 Javascript
javascript实现3D变换的立体圆圈实例
2015/08/06 Javascript
Bootstrap表单布局样式源代码
2016/07/04 Javascript
easyui-combobox 实现简单的自动补全功能示例
2016/11/08 Javascript
js实现无缝轮播图
2020/03/09 Javascript
vue 解决mintui弹窗弹起来,底部页面滚动bug问题
2020/11/12 Javascript
[01:01:43]EG vs VP 2018国际邀请赛淘汰赛BO3 第二场 8.24
2018/08/25 DOTA
微信 用脚本查看是否被微信好友删除
2016/10/28 Python
解决python3 urllib中urlopen报错的问题
2017/03/25 Python
Python3实现抓取javascript动态生成的html网页功能示例
2017/08/22 Python
Python字符串格式化的方法(两种)
2017/09/19 Python
python利用微信公众号实现报警功能
2018/06/10 Python
python 3.6.5 安装配置方法图文教程
2018/09/18 Python
在PyCharm中三步完成PyPy解释器的配置的方法
2018/10/29 Python
Python中的十大图像处理工具(小结)
2019/06/10 Python
PHP统计代码行数的小代码
2019/09/19 Python
基于Pyinstaller打包Python程序并压缩文件大小
2020/05/28 Python
Python 生成短8位唯一id实战教程
2021/01/13 Python
解决html5中的video标签ios系统中无法播放使用的问题
2020/08/10 HTML / CSS
奥地利领先的在线药房:SHOP APOTHEKE
2019/10/07 全球购物
电影T恤、80年代T恤和80年代服装:TV Store Online
2020/01/05 全球购物
JD Sports澳洲官网:英国领先的运动鞋和运动时尚零售商
2020/02/15 全球购物
资金主管岗位职责范本
2014/03/04 职场文书
《火烧云》教学反思
2014/04/12 职场文书
党员示范岗材料
2014/12/19 职场文书
孟佩杰观后感
2015/06/17 职场文书
消费者投诉书范文
2015/07/02 职场文书
学生会2016感恩节活动小结
2016/04/01 职场文书