TensorFlow搭建神经网络最佳实践


Posted in Python onMarch 09, 2018

一、TensorFLow完整样例

在MNIST数据集上,搭建一个简单神经网络结构,一个包含ReLU单元的非线性化处理的两层神经网络。在训练神经网络的时候,使用带指数衰减的学习率设置、使用正则化来避免过拟合、使用滑动平均模型来使得最终的模型更加健壮。

程序将计算神经网络前向传播的部分单独定义一个函数inference,训练部分定义一个train函数,再定义一个主函数main。

完整程序:

#!/usr/bin/env python3 
# -*- coding: utf-8 -*- 
""" 
Created on Thu May 25 08:56:30 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
INPUT_NODE = 784 # 输入节点数 
OUTPUT_NODE = 10 # 输出节点数 
LAYER1_NODE = 500 # 隐含层节点数 
BATCH_SIZE = 100 
LEARNING_RETE_BASE = 0.8 # 基学习率 
LEARNING_RETE_DECAY = 0.99 # 学习率的衰减率 
REGULARIZATION_RATE = 0.0001 # 正则化项的权重系数 
TRAINING_STEPS = 10000 # 迭代训练次数 
MOVING_AVERAGE_DECAY = 0.99 # 滑动平均的衰减系数 
 
# 传入神经网络的权重和偏置,计算神经网络前向传播的结果 
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2): 
  # 判断是否传入ExponentialMovingAverage类对象 
  if avg_class == None: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1) 
    return tf.matmul(layer1, weights2) + biases2 
  else: 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) 
                   + avg_class.average(biases1)) 
    return tf.matmul(layer1, avg_class.average(weights2))\ 
             + avg_class.average(biases2) 
 
# 神经网络模型的训练过程 
def train(mnist): 
  x = tf.placeholder(tf.float32, [None,INPUT_NODE], name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input') 
 
  # 定义神经网络结构的参数 
  weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], 
                        stddev=0.1)) 
  biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE])) 
  weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], 
                        stddev=0.1)) 
  biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE])) 
 
  # 计算非滑动平均模型下的参数的前向传播的结果 
  y = inference(x, None, weights1, biases1, weights2, biases2) 
   
  global_step = tf.Variable(0, trainable=False) # 定义存储当前迭代训练轮数的变量 
 
  # 定义ExponentialMovingAverage类对象 
  variable_averages = tf.train.ExponentialMovingAverage( 
            MOVING_AVERAGE_DECAY, global_step) # 传入当前迭代轮数参数 
  # 定义对所有可训练变量trainable_variables进行更新滑动平均值的操作op 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
 
  # 计算滑动模型下的参数的前向传播的结果 
  average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2) 
 
  # 定义交叉熵损失值 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 
          logits=y, labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  # 定义L2正则化器并对weights1和weights2正则化 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  regularization = regularizer(weights1) + regularizer(weights2) 
  loss = cross_entropy_mean + regularization # 总损失值 
 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RETE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RETE_DECAY) 
  # 定义梯度下降操作op,global_step参数可实现自加1运算 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
             .minimize(loss, global_step=global_step) 
  # 组合两个操作op 
  train_op = tf.group(train_step, variables_averages_op) 
  ''''' 
  # 与tf.group()等价的语句 
  with tf.control_dependencies([train_step, variables_averages_op]): 
    train_op = tf.no_op(name='train') 
  ''' 
  # 定义准确率 
  # 在最终预测的时候,神经网络的输出采用的是经过滑动平均的前向传播计算结果 
  correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1)) 
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
  # 初始化回话sess并开始迭代训练 
  with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # 验证集待喂入数据 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试集待喂入数据 
    test_feed = {x: mnist.test.images, y_: mnist.test.labels} 
    for i in range(TRAINING_STEPS): 
      if i % 1000 == 0: 
        validate_acc = sess.run(accuracy, feed_dict=validate_feed) 
        print('After %d training steps, validation accuracy' 
           ' using average model is %f' % (i, validate_acc)) 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      sess.run(train_op, feed_dict={x: xs, y_:ys}) 
 
    test_acc = sess.run(accuracy, feed_dict=test_feed) 
    print('After %d training steps, test accuracy' 
       ' using average model is %f' % (TRAINING_STEPS, test_acc)) 
 
# 主函数 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
# 当前的python文件是shell文件执行的入口文件,而非当做import的python module。 
if __name__ == '__main__': # 在模块内部执行 
  tf.app.run() # 调用main函数并传入所需的参数list

二、分析与改进设计

1. 程序分析改进

第一,计算前向传播的函数inference中需要将所有的变量以参数的形式传入函数,当神经网络结构变得更加复杂、参数更多的时候,程序的可读性将变得非常差。

第二,在程序退出时,训练好的模型就无法再利用,且大型神经网络的训练时间都比较长,在训练过程中需要每隔一段时间保存一次模型训练的中间结果,这样如果在训练过程中程序死机,死机前的最新的模型参数仍能保留,杜绝了时间和资源的浪费。

第三,将训练和测试分成两个独立的程序,将训练和测试都会用到的前向传播的过程抽象成单独的库函数。这样就保证了在训练和预测两个过程中所调用的前向传播计算程序是一致的。

2. 改进后程序设计

mnist_inference.py

该文件中定义了神经网络的前向传播过程,其中的多次用到的weights定义过程又单独定义成函数。

通过tf.get_variable函数来获取变量,在神经网络训练时创建这些变量,在测试时会通过保存的模型加载这些变量的取值,而且可以在变量加载时将滑动平均值重命名。所以可以直接通过同样的名字在训练时使用变量自身,在测试时使用变量的滑动平均值。

mnist_train.py

该程序给出了神经网络的完整训练过程。

mnist_eval.py

在滑动平均模型上做测试。

通过tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)获取最新模型的文件名,实际是获取checkpoint文件的所有内容。

三、TensorFlow最佳实践样例

mnist_inference.py

import tensorflow as tf 
 
INPUT_NODE = 784 
OUTPUT_NODE = 10 
LAYER1_NODE = 500 
 
def get_weight_variable(shape, regularizer): 
  weights = tf.get_variable("weights", shape, 
         initializer=tf.truncated_normal_initializer(stddev=0.1)) 
  if regularizer != None: 
    # 将权重参数的正则化项加入至损失集合 
    tf.add_to_collection('losses', regularizer(weights)) 
  return weights 
 
def inference(input_tensor, regularizer): 
  with tf.variable_scope('layer1'): 
    weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) 
    biases = tf.get_variable("biases", [LAYER1_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) 
 
  with tf.variable_scope('layer2'): 
    weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer) 
    biases = tf.get_variable("biases", [OUTPUT_NODE], 
                 initializer=tf.constant_initializer(0.0)) 
    layer2 = tf.matmul(layer1, weights) + biases 
 
  return layer2

mnist_train.py

import os 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
 
BATCH_SIZE = 100 
LEARNING_RATE_BASE = 0.8 
LEARNING_RATE_DECAY = 0.99 
REGULARIZATION_RATE = 0.0001 
TRAINING_STEPS = 10000 
MOVING_AVERAGE_DECAY = 0.99 
 
MODEL_SAVE_PATH = "Model_Folder/" 
MODEL_NAME = "model.ckpt" 
 
def train(mnist): 
  # 定义输入placeholder 
  x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
            name='x-input') 
  y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
            name='y-input') 
  # 定义正则化器及计算前向过程输出 
  regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 
  y = mnist_inference.inference(x, regularizer) 
  # 定义当前训练轮数及滑动平均模型 
  global_step = tf.Variable(0, trainable=False) 
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, 
                             global_step) 
  variables_averages_op = variable_averages.apply(tf.trainable_variables()) 
  # 定义损失函数 
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, 
                          labels=tf.argmax(y_, 1)) 
  cross_entropy_mean = tf.reduce_mean(cross_entropy) 
  loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) 
  # 定义指数衰减学习率 
  learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, 
          mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY) 
  # 定义训练操作,包括模型训练及滑动模型操作 
  train_step = tf.train.GradientDescentOptimizer(learning_rate)\ 
          .minimize(loss, global_step=global_step) 
  train_op = tf.group(train_step, variables_averages_op) 
  # 定义Saver类对象,保存模型,TensorFlow持久化类 
  saver = tf.train.Saver() 
 
  # 定义会话,启动训练过程 
  with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
 
    for i in range(TRAINING_STEPS): 
      xs, ys = mnist.train.next_batch(BATCH_SIZE) 
      _, loss_value, step = sess.run([train_op, loss, global_step], 
                      feed_dict={x: xs, y_: ys}) 
      if i % 1000 == 0: 
        print("After %d training step(s), loss on training batch is %g."\ 
            % (step, loss_value)) 
        # save方法的global_step参数可以让每个被保存的模型的文件名末尾加上当前训练轮数 
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), 
              global_step=global_step) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  train(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

mnist_eval.py

import time 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
import mnist_inference 
import mnist_train 
 
EVAL_INTERVAL_SECS = 10 
 
def evaluate(mnist): 
  with tf.Graph().as_default() as g: 
    # 定义输入placeholder 
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], 
              name='x-input') 
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], 
              name='y-input') 
    # 定义feed字典 
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 
    # 测试时不加参数正则化损失 
    y = mnist_inference.inference(x, None) 
    # 计算正确率 
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
    # 加载滑动平均模型下的参数值 
    variable_averages = tf.train.ExponentialMovingAverage( 
                   mnist_train.MOVING_AVERAGE_DECAY) 
    saver = tf.train.Saver(variable_averages.variables_to_restore()) 
 
    # 每隔EVAL_INTERVAL_SECS秒启动一次会话 
    while True: 
      with tf.Session() as sess: 
        ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH) 
        if ckpt and ckpt.model_checkpoint_path: 
          saver.restore(sess, ckpt.model_checkpoint_path) 
          # 取checkpoint文件中的当前迭代轮数global_step 
          global_step = ckpt.model_checkpoint_path\ 
                   .split('/')[-1].split('-')[-1] 
          accuracy_score = sess.run(accuracy, feed_dict=validate_feed) 
          print("After %s training step(s), validation accuracy = %g"\ 
             % (global_step, accuracy_score)) 
 
        else: 
          print('No checkpoint file found') 
          return 
      time.sleep(EVAL_INTERVAL_SECS) 
 
def main(argv=None): 
  mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
  evaluate(mnist) 
 
if __name__ == '__main__': 
  tf.app.run()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python dict remove数组删除(del,pop)
Mar 24 Python
python入门教程之识别验证码
Mar 04 Python
python实现随机森林random forest的原理及方法
Dec 21 Python
python处理数据,存进hive表的方法
Jul 04 Python
python3.7.0的安装步骤
Aug 27 Python
Python sorted函数详解(高级篇)
Sep 18 Python
python脚本执行CMD命令并返回结果的例子
Aug 14 Python
python调用支付宝支付接口流程
Aug 15 Python
基于Python实现ComicReaper漫画自动爬取脚本过程解析
Nov 11 Python
解决pycharm上的jupyter notebook端口被占用问题
Dec 17 Python
python 计算方位角实例(根据两点的坐标计算)
Jan 17 Python
python匿名函数lambda原理及实例解析
Feb 07 Python
TensorFlow实现Batch Normalization
Mar 08 #Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
You might like
php连接mysql数据库最简单的实现方法
2019/09/24 PHP
用js实现手把手教你月入万刀(转贴)
2007/11/07 Javascript
JavaScript 学习笔记 Black.Caffeine 09.11.28
2009/11/30 Javascript
JavaScript中出现乱码的处理心得
2009/12/24 Javascript
javascript结合fileReader 实现上传图片
2015/01/30 Javascript
IE下支持文本框和密码框placeholder效果的JQuery插件分享
2015/01/31 Javascript
JavaScript中Boolean对象的属性解析
2015/10/21 Javascript
JS获取地址栏参数的两种方法(简单实用)
2016/06/14 Javascript
EasyUI Pagination 分页的两种做法小结
2016/07/09 Javascript
require.js配合插件text.js实现最简单的单页应用程序
2016/07/12 Javascript
JS实现的简单拖拽购物车功能示例【附源码下载】
2018/01/03 Javascript
在 React、Vue项目中使用SVG的方法
2018/02/09 Javascript
Vue插件打包与发布的方法示例
2018/08/20 Javascript
es6基础学习之解构赋值
2018/12/10 Javascript
详解微信小程序-canvas绘制文字实现自动换行
2019/04/26 Javascript
vue2配置scss的方法步骤
2019/06/06 Javascript
layui自定义ajax左侧三级菜单
2019/07/26 Javascript
Vue中点击active并第一个默认选中功能的实现
2020/02/24 Javascript
[01:38]2018DOTA2亚洲邀请赛主赛事第二日现场采访 神秘商人痛陈生计不易
2018/04/05 DOTA
跟老齐学Python之有点简约的元组
2014/09/24 Python
Python实现感知机(PLA)算法
2017/12/20 Python
详解python之heapq模块及排序操作
2019/04/04 Python
python re.match()用法相关示例
2021/01/27 Python
CSS3选择器新增问题的实现
2021/01/21 HTML / CSS
美国知名珠宝首饰品牌:Gemvara
2017/10/06 全球购物
西铁城美国官方网站:Citizen Watch美国
2019/11/08 全球购物
linux面试题参考答案(4)
2014/09/21 面试题
贷款承诺书
2015/01/20 职场文书
2016年春季运动会广播稿
2015/08/19 职场文书
2016应届毕业生实习心得体会
2015/10/09 职场文书
党员公开承诺书(2016最新版)
2016/03/24 职场文书
PHP连接MSSQL数据库案例,PHPWAMP多个PHP版本连接SQL Server数据库
2021/04/16 PHP
使用Python的开发框架Brownie部署以太坊智能合约
2021/05/28 Python
SpringBoot中HttpSessionListener的简单使用方式
2022/03/17 Java/Android
Spring Cloud Netflix 套件中的负载均衡组件 Ribbon
2022/04/13 Java/Android
el-table-column 内容不自动换行的解决方法
2022/08/14 Vue.js