TensorFlow搭建神经网络最佳实践


Posted in Python onMarch 09, 2018

一、TensorFLow完整样例

在MNIST数据集上,搭建一个简单神经网络结构,一个包含ReLU单元的非线性化处理的两层神经网络。在训练神经网络的时候,使用带指数衰减的学习率设置、使用正则化来避免过拟合、使用滑动平均模型来使得最终的模型更加健壮。

程序将计算神经网络前向传播的部分单独定义一个函数inference,训练部分定义一个train函数,再定义一个主函数main。

完整程序:

#!/usr/bin/env python3 
# -*- coding: utf-8 -*- 
""" 
Created on Thu May 25 08:56:30 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
INPUT_NODE = 784 # 输入节点数 
OUTPUT_NODE = 10 # 输出节点数 
LAYER1_NODE = 500 # 隐含层节点数 
BATCH_SIZE = 100 
LEARNING_RETE_BASE = 0.8 # 基学习率 
LEARNING_RETE_DECAY = 0.99 # 学习率的衰减率 
REGULARIZATION_RATE = 0.0001 # 正则化项的权重系数 
TRAINING_STEPS = 10000 # 迭代训练次数 
MOVING_AVERAGE_DECAY = 0.99 # 滑动平均的衰减系数 
 
# 传入神经网络的权重和偏置,计算神经网络前向传播的结果 
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2): 
  # 判断是否传入ExponentialMovingAverage类对象 
  if avg_class == None: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1) 
    return tf.matmul(layer1, weights2) + biases2 
  else: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) 
                   + avg_class.average(biases1)) 
    return tf.matmul(layer1, avg_class.average(weights2))\ 
             + avg_class.average(biases2) 
 
# 神经网络模型的训练过程 
def train(mnist): 
  x = tf.placeholder(tf.float32, [None,INPUT_NODE], name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input') 
 
  # 定义神经网络结构的参数 
  weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], 
                        stddev=0.1)) 
  biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE])) 
  weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], 
                        stddev=0.1)) 
  biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE])) 
 
  # 计算非滑动平均模型下的参数的前向传播的结果 
  y = inference(x, None, weights1, biases1, weights2, biases2) 
   
  global_step = tf.Variable(0, trainable=False) # 定义存储当前迭代训练轮数的变量 
 
  # 定义ExponentialMovingAverage类对象 
  variable_averages = tf.train.ExponentialMovingAverage( 
            MOVING_AVERAGE_DECAY, global_step) # 传入当前迭代轮数参数 
  # 定义对所有可训练变量trainable_variables进行更新滑动平均值的操作op 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
 
  # 计算滑动模型下的参数的前向传播的结果 
  average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2) 
 
  # 定义交叉熵损失值 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 
          logits=y, labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  # 定义L2正则化器并对weights1和weights2正则化 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  regularization = regularizer(weights1) + regularizer(weights2) 
  loss = cross_entropy_mean + regularization # 总损失值 
 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RETE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RETE_DECAY) 
  # 定义梯度下降操作op,global_step参数可实现自加1运算 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
             .minimize(loss, global_step=global_step) 
  # 组合两个操作op 
  train_op = tf.group(train_step, variables_averages_op) 
  ''''' 
  # 与tf.group()等价的语句 
  with tf.control_dependencies([train_step, variables_averages_op]): 
    train_op = tf.no_op(name='train') 
  ''' 
  # 定义准确率 
  # 在最终预测的时候,神经网络的输出采用的是经过滑动平均的前向传播计算结果 
  correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1)) 
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
  # 初始化回话sess并开始迭代训练 
  with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # 验证集待喂入数据 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试集待喂入数据 
    test_feed = {x: mnist.test.images, y_: mnist.test.labels} 
    for i in range(TRAINING_STEPS): 
      if i % 1000 == 0: 
        validate_acc = sess.run(accuracy, feed_dict=validate_feed) 
        print('After %d training steps, validation accuracy' 
           ' using average model is %f' % (i, validate_acc)) 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      sess.run(train_op, feed_dict={x: xs, y_:ys}) 
 
    test_acc = sess.run(accuracy, feed_dict=test_feed) 
    print('After %d training steps, test accuracy' 
       ' using average model is %f' % (TRAINING_STEPS, test_acc)) 
 
# 主函数 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
# 当前的python文件是shell文件执行的入口文件,而非当做import的python module。 
if __name__ == '__main__': # 在模块内部执行 
  tf.app.run() # 调用main函数并传入所需的参数list

二、分析与改进设计

1. 程序分析改进

第一,计算前向传播的函数inference中需要将所有的变量以参数的形式传入函数,当神经网络结构变得更加复杂、参数更多的时候,程序的可读性将变得非常差。

第二,在程序退出时,训练好的模型就无法再利用,且大型神经网络的训练时间都比较长,在训练过程中需要每隔一段时间保存一次模型训练的中间结果,这样如果在训练过程中程序死机,死机前的最新的模型参数仍能保留,杜绝了时间和资源的浪费。

第三,将训练和测试分成两个独立的程序,将训练和测试都会用到的前向传播的过程抽象成单独的库函数。这样就保证了在训练和预测两个过程中所调用的前向传播计算程序是一致的。

2. 改进后程序设计

mnist_inference.py

该文件中定义了神经网络的前向传播过程,其中的多次用到的weights定义过程又单独定义成函数。

通过tf.get_variable函数来获取变量,在神经网络训练时创建这些变量,在测试时会通过保存的模型加载这些变量的取值,而且可以在变量加载时将滑动平均值重命名。所以可以直接通过同样的名字在训练时使用变量自身,在测试时使用变量的滑动平均值。

mnist_train.py

该程序给出了神经网络的完整训练过程。

mnist_eval.py

在滑动平均模型上做测试。

通过tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)获取最新模型的文件名,实际是获取checkpoint文件的所有内容。

三、TensorFlow最佳实践样例

mnist_inference.py

import tensorflow as tf 
 
INPUT_NODE = 784 
OUTPUT_NODE = 10 
LAYER1_NODE = 500 
 
def get_weight_variable(shape, regularizer): 
  weights = tf.get_variable("weights", shape, 
         initializer=tf.truncated_normal_initializer(stddev=0.1)) 
  if regularizer != None: 
    # 将权重参数的正则化项加入至损失集合 
    tf.add_to_collection('losses', regularizer(weights)) 
  return weights 
 
def inference(input_tensor, regularizer): 
  with tf.variable_scope('layer1'): 
    weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) 
    biases = tf.get_variable("biases", [LAYER1_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) 
 
  with tf.variable_scope('layer2'): 
    weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer) 
    biases = tf.get_variable("biases", [OUTPUT_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer2 = tf.matmul(layer1, weights) + biases 
 
  return layer2

mnist_train.py

import os 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
 
BATCH_SIZE = 100 
LEARNING_RATE_BASE = 0.8 
LEARNING_RATE_DECAY = 0.99 
REGULARIZATION_RATE = 0.0001 
TRAINING_STEPS = 10000 
MOVING_AVERAGE_DECAY = 0.99 
 
MODEL_SAVE_PATH = "Model_Folder/" 
MODEL_NAME = "model.ckpt" 
 
def train(mnist): 
  # 定义输入placeholder 
  x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
            name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
            name='y-input') 
  # 定义正则化器及计算前向过程输出 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  y = mnist_inference.inference(x, regularizer) 
  # 定义当前训练轮数及滑动平均模型 
  global_step = tf.Variable(0, trainable=False) 
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, 
                             global_step) 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
  # 定义损失函数 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, 
                          labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY) 
  # 定义训练操作,包括模型训练及滑动模型操作 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
          .minimize(loss, global_step=global_step) 
  train_op = tf.group(train_step, variables_averages_op) 
  # 定义Saver类对象,保存模型,TensorFlow持久化类 
  saver = tf.train.Saver() 
 
  # 定义会话,启动训练过程 
  with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
 
    for i in range(TRAINING_STEPS): 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      _, loss_value, step = sess.run([train_op, loss, global_step], 
                      feed_dict={x: xs, y_: ys}) 
      if i % 1000 == 0: 
        print("After %d training step(s), loss on training batch is %g."\ 
            % (step, loss_value)) 
        # save方法的global_step参数可以让每个被保存的模型的文件名末尾加上当前训练轮数 
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), 
              global_step=global_step) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

mnist_eval.py

import time 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
import mnist_train 
 
EVAL_INTERVAL_SECS = 10 
 
def evaluate(mnist): 
  with tf.Graph().as_default() as g: 
    # 定义输入placeholder 
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
              name='x-input') 
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
              name='y-input') 
    # 定义feed字典 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试时不加参数正则化损失 
    y = mnist_inference.inference(x, None) 
    # 计算正确率 
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
    # 加载滑动平均模型下的参数值 
    variable_averages = tf.train.ExponentialMovingAverage( 
                   mnist_train.MOVING_AVERAGE_DECAY) 
    saver = tf.train.Saver(variable_averages.variables_to_restore()) 
 
    # 每隔EVAL_INTERVAL_SECS秒启动一次会话 
    while True: 
      with tf.Session() as sess: 
        ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH) 
        if ckpt and ckpt.model_checkpoint_path: 
          saver.restore(sess, ckpt.model_checkpoint_path) 
          # 取checkpoint文件中的当前迭代轮数global_step 
          global_step = ckpt.model_checkpoint_path\ 
                   .split('/')[-1].split('-')[-1] 
          accuracy_score = sess.run(accuracy, feed_dict=validate_feed) 
          print("After %s training step(s), validation accuracy = %g"\ 
             % (global_step, accuracy_score)) 
 
        else: 
          print('No checkpoint file found') 
          return 
      time.sleep(EVAL_INTERVAL_SECS) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  evaluate(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现统计代码行数的方法
May 22 Python
Python实现全角半角字符互转的方法
Nov 28 Python
Python爬虫中urllib库的进阶学习
Jan 05 Python
Python根据已知邻接矩阵绘制无向图操作示例
Jun 23 Python
python使用itchat模块给心爱的人每天发天气预报
Nov 25 Python
在notepad++中实现直接运行python代码
Dec 18 Python
利用python在excel中画图的实现方法
Mar 17 Python
python使用OpenCV模块实现图像的融合示例代码
Apr 10 Python
QML实现钟表效果
Jun 02 Python
Python Request类源码实现方法及原理解析
Aug 17 Python
Ubuntu权限不足无法创建文件夹解决方案
Nov 14 Python
Python超简单容易上手的画图工具库推荐
May 10 Python
TensorFlow实现Batch Normalization
Mar 08 #Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
You might like
php 需要掌握的东西 不做浮躁的人
2009/12/28 PHP
PHP教程之PHP中shell脚本的使用方法分享
2012/02/23 PHP
PHP通过插入mysql数据来实现多机互锁实例
2014/11/05 PHP
zend framework中使用memcache的方法
2016/03/04 PHP
php中文语义分析实现方法示例
2019/09/28 PHP
javascript Xml增删改查(IE下)操作实现代码
2009/01/30 Javascript
基于jQuery的淡入淡出可自动切换的幻灯插件
2010/08/24 Javascript
JQuery for与each性能比较分析
2013/05/14 Javascript
javascript制作的网页侧边弹出框思路及实现代码
2014/05/21 Javascript
js全选实现和判断是否有复选框选中的方法
2015/02/17 Javascript
JavaScript实现的一个倒计时的类
2015/03/12 Javascript
JavaScript中的函数声明和函数表达式区别浅析
2015/03/27 Javascript
jQuery控制元素显示、隐藏、切换、滑动的方法总结
2015/04/16 Javascript
简介BootStrap model弹出框的使用
2016/04/27 Javascript
JavaScript函数中关于valueOf和toString的理解
2016/06/14 Javascript
JAVA Web实时消息后台服务器推送技术---GoEasy
2016/11/04 Javascript
JS实现的点击表头排序功能示例
2017/03/27 Javascript
jQuery取得元素标签名称小结(附代码)
2017/08/16 jQuery
vue实现购物车的小练习
2020/12/21 Vue.js
[42:32]DOTA2上海特级锦标赛B组资格赛#2 Fnatic VS Spirit第二局
2016/02/27 DOTA
[40:03]Liquid vs Optic 2018国际邀请赛淘汰赛BO3 第一场 8.21
2018/08/22 DOTA
python学习--使用QQ邮箱发送邮件代码实例
2019/04/16 Python
Python实现的远程文件自动打包并下载功能示例
2019/07/12 Python
python 视频逐帧保存为图片的完整实例
2019/12/10 Python
浅谈Pycharm的项目文件名是红色的原因及解决方式
2020/06/01 Python
美国最好的钓鱼、狩猎和划船装备商店:Bass Pro Shops
2018/12/02 全球购物
瑞士图书网站:Weltbild.ch
2019/09/17 全球购物
广州地球村科技数据库题目
2016/04/25 面试题
银行开业庆典方案
2014/02/06 职场文书
行政助理的岗位职责
2014/02/18 职场文书
毕业生自荐信格式
2014/03/07 职场文书
2015年个人招商工作总结
2015/04/25 职场文书
同学会感言
2015/07/30 职场文书
HTML5简单实现添加背景音乐的几种方法
2021/05/12 HTML / CSS
Windows安装Anaconda3的方法及使用过程详解
2021/06/11 Python
Java并发编程之详解CyclicBarrier线程同步
2021/06/23 Java/Android