TensorFlow搭建神经网络最佳实践


Posted in Python onMarch 09, 2018

一、TensorFLow完整样例

在MNIST数据集上,搭建一个简单神经网络结构,一个包含ReLU单元的非线性化处理的两层神经网络。在训练神经网络的时候,使用带指数衰减的学习率设置、使用正则化来避免过拟合、使用滑动平均模型来使得最终的模型更加健壮。

程序将计算神经网络前向传播的部分单独定义一个函数inference,训练部分定义一个train函数,再定义一个主函数main。

完整程序:

#!/usr/bin/env python3 
# -*- coding: utf-8 -*- 
""" 
Created on Thu May 25 08:56:30 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
INPUT_NODE = 784 # 输入节点数 
OUTPUT_NODE = 10 # 输出节点数 
LAYER1_NODE = 500 # 隐含层节点数 
BATCH_SIZE = 100 
LEARNING_RETE_BASE = 0.8 # 基学习率 
LEARNING_RETE_DECAY = 0.99 # 学习率的衰减率 
REGULARIZATION_RATE = 0.0001 # 正则化项的权重系数 
TRAINING_STEPS = 10000 # 迭代训练次数 
MOVING_AVERAGE_DECAY = 0.99 # 滑动平均的衰减系数 
 
# 传入神经网络的权重和偏置,计算神经网络前向传播的结果 
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2): 
  # 判断是否传入ExponentialMovingAverage类对象 
  if avg_class == None: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1) 
    return tf.matmul(layer1, weights2) + biases2 
  else: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) 
                   + avg_class.average(biases1)) 
    return tf.matmul(layer1, avg_class.average(weights2))\ 
             + avg_class.average(biases2) 
 
# 神经网络模型的训练过程 
def train(mnist): 
  x = tf.placeholder(tf.float32, [None,INPUT_NODE], name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input') 
 
  # 定义神经网络结构的参数 
  weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], 
                        stddev=0.1)) 
  biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE])) 
  weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], 
                        stddev=0.1)) 
  biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE])) 
 
  # 计算非滑动平均模型下的参数的前向传播的结果 
  y = inference(x, None, weights1, biases1, weights2, biases2) 
   
  global_step = tf.Variable(0, trainable=False) # 定义存储当前迭代训练轮数的变量 
 
  # 定义ExponentialMovingAverage类对象 
  variable_averages = tf.train.ExponentialMovingAverage( 
            MOVING_AVERAGE_DECAY, global_step) # 传入当前迭代轮数参数 
  # 定义对所有可训练变量trainable_variables进行更新滑动平均值的操作op 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
 
  # 计算滑动模型下的参数的前向传播的结果 
  average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2) 
 
  # 定义交叉熵损失值 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 
          logits=y, labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  # 定义L2正则化器并对weights1和weights2正则化 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  regularization = regularizer(weights1) + regularizer(weights2) 
  loss = cross_entropy_mean + regularization # 总损失值 
 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RETE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RETE_DECAY) 
  # 定义梯度下降操作op,global_step参数可实现自加1运算 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
             .minimize(loss, global_step=global_step) 
  # 组合两个操作op 
  train_op = tf.group(train_step, variables_averages_op) 
  ''''' 
  # 与tf.group()等价的语句 
  with tf.control_dependencies([train_step, variables_averages_op]): 
    train_op = tf.no_op(name='train') 
  ''' 
  # 定义准确率 
  # 在最终预测的时候,神经网络的输出采用的是经过滑动平均的前向传播计算结果 
  correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1)) 
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
  # 初始化回话sess并开始迭代训练 
  with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # 验证集待喂入数据 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试集待喂入数据 
    test_feed = {x: mnist.test.images, y_: mnist.test.labels} 
    for i in range(TRAINING_STEPS): 
      if i % 1000 == 0: 
        validate_acc = sess.run(accuracy, feed_dict=validate_feed) 
        print('After %d training steps, validation accuracy' 
           ' using average model is %f' % (i, validate_acc)) 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      sess.run(train_op, feed_dict={x: xs, y_:ys}) 
 
    test_acc = sess.run(accuracy, feed_dict=test_feed) 
    print('After %d training steps, test accuracy' 
       ' using average model is %f' % (TRAINING_STEPS, test_acc)) 
 
# 主函数 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
# 当前的python文件是shell文件执行的入口文件,而非当做import的python module。 
if __name__ == '__main__': # 在模块内部执行 
  tf.app.run() # 调用main函数并传入所需的参数list

二、分析与改进设计

1. 程序分析改进

第一,计算前向传播的函数inference中需要将所有的变量以参数的形式传入函数,当神经网络结构变得更加复杂、参数更多的时候,程序的可读性将变得非常差。

第二,在程序退出时,训练好的模型就无法再利用,且大型神经网络的训练时间都比较长,在训练过程中需要每隔一段时间保存一次模型训练的中间结果,这样如果在训练过程中程序死机,死机前的最新的模型参数仍能保留,杜绝了时间和资源的浪费。

第三,将训练和测试分成两个独立的程序,将训练和测试都会用到的前向传播的过程抽象成单独的库函数。这样就保证了在训练和预测两个过程中所调用的前向传播计算程序是一致的。

2. 改进后程序设计

mnist_inference.py

该文件中定义了神经网络的前向传播过程,其中的多次用到的weights定义过程又单独定义成函数。

通过tf.get_variable函数来获取变量,在神经网络训练时创建这些变量,在测试时会通过保存的模型加载这些变量的取值,而且可以在变量加载时将滑动平均值重命名。所以可以直接通过同样的名字在训练时使用变量自身,在测试时使用变量的滑动平均值。

mnist_train.py

该程序给出了神经网络的完整训练过程。

mnist_eval.py

在滑动平均模型上做测试。

通过tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)获取最新模型的文件名,实际是获取checkpoint文件的所有内容。

三、TensorFlow最佳实践样例

mnist_inference.py

import tensorflow as tf 
 
INPUT_NODE = 784 
OUTPUT_NODE = 10 
LAYER1_NODE = 500 
 
def get_weight_variable(shape, regularizer): 
  weights = tf.get_variable("weights", shape, 
         initializer=tf.truncated_normal_initializer(stddev=0.1)) 
  if regularizer != None: 
    # 将权重参数的正则化项加入至损失集合 
    tf.add_to_collection('losses', regularizer(weights)) 
  return weights 
 
def inference(input_tensor, regularizer): 
  with tf.variable_scope('layer1'): 
    weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) 
    biases = tf.get_variable("biases", [LAYER1_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) 
 
  with tf.variable_scope('layer2'): 
    weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer) 
    biases = tf.get_variable("biases", [OUTPUT_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer2 = tf.matmul(layer1, weights) + biases 
 
  return layer2

mnist_train.py

import os 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
 
BATCH_SIZE = 100 
LEARNING_RATE_BASE = 0.8 
LEARNING_RATE_DECAY = 0.99 
REGULARIZATION_RATE = 0.0001 
TRAINING_STEPS = 10000 
MOVING_AVERAGE_DECAY = 0.99 
 
MODEL_SAVE_PATH = "Model_Folder/" 
MODEL_NAME = "model.ckpt" 
 
def train(mnist): 
  # 定义输入placeholder 
  x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
            name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
            name='y-input') 
  # 定义正则化器及计算前向过程输出 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  y = mnist_inference.inference(x, regularizer) 
  # 定义当前训练轮数及滑动平均模型 
  global_step = tf.Variable(0, trainable=False) 
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, 
                             global_step) 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
  # 定义损失函数 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, 
                          labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY) 
  # 定义训练操作,包括模型训练及滑动模型操作 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
          .minimize(loss, global_step=global_step) 
  train_op = tf.group(train_step, variables_averages_op) 
  # 定义Saver类对象,保存模型,TensorFlow持久化类 
  saver = tf.train.Saver() 
 
  # 定义会话,启动训练过程 
  with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
 
    for i in range(TRAINING_STEPS): 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      _, loss_value, step = sess.run([train_op, loss, global_step], 
                      feed_dict={x: xs, y_: ys}) 
      if i % 1000 == 0: 
        print("After %d training step(s), loss on training batch is %g."\ 
            % (step, loss_value)) 
        # save方法的global_step参数可以让每个被保存的模型的文件名末尾加上当前训练轮数 
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), 
              global_step=global_step) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

mnist_eval.py

import time 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
import mnist_train 
 
EVAL_INTERVAL_SECS = 10 
 
def evaluate(mnist): 
  with tf.Graph().as_default() as g: 
    # 定义输入placeholder 
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
              name='x-input') 
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
              name='y-input') 
    # 定义feed字典 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试时不加参数正则化损失 
    y = mnist_inference.inference(x, None) 
    # 计算正确率 
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
    # 加载滑动平均模型下的参数值 
    variable_averages = tf.train.ExponentialMovingAverage( 
                   mnist_train.MOVING_AVERAGE_DECAY) 
    saver = tf.train.Saver(variable_averages.variables_to_restore()) 
 
    # 每隔EVAL_INTERVAL_SECS秒启动一次会话 
    while True: 
      with tf.Session() as sess: 
        ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH) 
        if ckpt and ckpt.model_checkpoint_path: 
          saver.restore(sess, ckpt.model_checkpoint_path) 
          # 取checkpoint文件中的当前迭代轮数global_step 
          global_step = ckpt.model_checkpoint_path\ 
                   .split('/')[-1].split('-')[-1] 
          accuracy_score = sess.run(accuracy, feed_dict=validate_feed) 
          print("After %s training step(s), validation accuracy = %g"\ 
             % (global_step, accuracy_score)) 
 
        else: 
          print('No checkpoint file found') 
          return 
      time.sleep(EVAL_INTERVAL_SECS) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  evaluate(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python GAE、Django导出Excel的方法
Nov 24 Python
Python中的闭包详细介绍和实例
Nov 21 Python
使用Python求解最大公约数的实现方法
Aug 20 Python
浅谈Python的垃圾回收机制
Dec 17 Python
Python基于SMTP协议实现发送邮件功能详解
Aug 14 Python
python实现决策树分类
Aug 30 Python
python正则表达式匹配[]中间为任意字符的实例
Dec 25 Python
基于OpenCV python3实现证件照换背景的方法
Mar 22 Python
python解释器spython使用及原理解析
Aug 24 Python
Python基于time模块表示时间常用方法
Jun 18 Python
Python离线安装openpyxl模块的步骤
Mar 30 Python
django如何自定义manage.py管理命令
Apr 27 Python
TensorFlow实现Batch Normalization
Mar 08 #Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
You might like
php daodb插入、更新与删除数据
2009/03/19 PHP
php在文件指定行中写入代码的方法
2012/05/23 PHP
Thinkphp搜索时首页分页和搜索页保持条件分页的方法
2014/12/05 PHP
php使用Jpgraph绘制饼状图的方法
2015/06/10 PHP
PHP图形计数器程序显示网站用户浏览量
2016/07/20 PHP
“不能执行已释放的Script代码”错误的原因及解决办法
2007/09/09 Javascript
传递参数的标准方法(jQuery.ajax)
2008/11/19 Javascript
Google排名中的10个最著名的 JavaScript库
2010/04/27 Javascript
jQuery学习笔记(1)--用jQuery实现异步通信(用json传值)具体思路
2013/04/08 Javascript
中文字符串截取的js函数代码
2013/04/17 Javascript
可选择和输入的下拉列表框示例
2013/11/05 Javascript
超级给力的JavaScript的React框架入门教程
2015/07/02 Javascript
JavaScript使用DeviceOne开发实战(四)仿优酷视频应用
2015/12/02 Javascript
详解jQuery移动页面开发中的ui-grid网格布局使用
2015/12/03 Javascript
js根据手机客户端浏览器类型,判断跳转官网/手机网站多个实例代码
2016/04/30 Javascript
jQuery版AJAX简易封装代码
2016/09/14 Javascript
BootStrap框架个人总结(bootstrap框架、导航条、下拉菜单、轮播广告carousel、栅格系统布局、标签页tabs、模态框、菜单定位)
2016/12/01 Javascript
jQuery实现导航回弹效果
2017/02/27 Javascript
自带气泡提示的vue校验插件(vue-verify-pop)
2017/04/07 Javascript
Angular.js实现动态加载组件详解
2017/05/28 Javascript
Vue项目使用CDN优化首屏加载问题
2018/04/01 Javascript
原生js实现淘宝放大镜效果
2020/10/28 Javascript
Electron-vue开发的客户端支付收款工具的实现
2019/05/24 Javascript
python 判断自定义对象类型
2009/03/21 Python
python搭建简易服务器分析与实现
2012/12/15 Python
酷! 程序员用Python带你玩转冲顶大会
2018/01/17 Python
用python实现将数组元素按从小到大的顺序排列方法
2018/07/02 Python
Python3.7实现中控考勤机自动连接
2018/08/28 Python
在python中对变量判断是否为None的三种方法总结
2019/01/23 Python
解决django后台管理界面添加中文内容乱码问题
2019/11/15 Python
wxPython实现文本框基础组件
2019/11/18 Python
简历中自我评价分享
2013/10/09 职场文书
初中三年学生的学习自我评价
2013/11/13 职场文书
一百条裙子读书笔记
2015/07/01 职场文书
如何理解Vue简单状态管理之store模式
2021/05/15 Vue.js
利用JuiceFS使MySQL 备份验证性能提升 10 倍
2022/03/17 MySQL