TensorFlow实现Batch Normalization


Posted in Python onMarch 08, 2018

一、BN(Batch Normalization)算法

1. 对数据进行归一化处理的重要性

神经网络学习过程的本质就是学习数据分布,在训练数据与测试数据分布不同情况下,模型的泛化能力就大大降低;另一方面,若训练过程中每批batch的数据分布也各不相同,那么网络每批迭代学习过程也会出现较大波动,使之更难趋于收敛,降低训练收敛速度。对于深层网络,网络前几层的微小变化都会被网络累积放大,则训练数据的分布变化问题会被放大,更加影响训练速度。

2. BN算法的强大之处

1)为了加速梯度下降算法的训练,我们可以采取指数衰减学习率等方法在初期快速学习,后期缓慢进入全局最优区域。使用BN算法后,就可以直接选择比较大的学习率,且设置很大的学习率衰减速度,大大提高训练速度。即使选择了较小的学习率,也会比以前不使用BN情况下的收敛速度快。总结就是BN算法具有快速收敛的特性。

2)BN具有提高网络泛化能力的特性。采用BN算法后,就可以移除针对过拟合问题而设置的dropout和L2正则化项,或者采用更小的L2正则化参数。

3)BN本身是一个归一化网络层,则局部响应归一化层(Local Response Normalization,LRN层)则可不需要了(Alexnet网络中使用到)。

3. BN算法概述

BN算法提出了变换重构,引入了可学习参数γ、β,这就是算法的关键之处:

TensorFlow实现Batch Normalization

引入这两个参数后,我们的网络便可以学习恢复出原是网络所要学习的特征分布,BN层的钱箱传到过程如下:

TensorFlow实现Batch Normalization

其中m为batchsize。BatchNormalization中所有的操作都是平滑可导,这使得back propagation可以有效运行并学到相应的参数γ,β。需要注意的一点是Batch Normalization在training和testing时行为有所差别。Training时μβ和σβ由当前batch计算得出;在Testing时μβ和σβ应使用Training时保存的均值或类似的经过处理的值,而不是由当前batch计算。

二、TensorFlow相关函数

1.tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)

x是输入张量,axes是在哪个维度上求解, 即想要 normalize的维度, [0] 代表 batch 维度,如果是图像数据,可以传入 [0, 1, 2],相当于求[batch, height, width] 的均值/方差,注意不要加入channel 维度。该函数返回两个张量,均值mean和方差variance。

2.tf.identity(input, name=None)

返回与输入张量input形状和内容一致的张量。

3.tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon,name=None)

计算公式为scale(x - mean)/ variance + offset。

这些参数中,tf.nn.moments可得到均值mean和方差variance,offset和scale是可训练的,offset一般初始化为0,scale初始化为1,offset和scale的shape与mean相同,variance_epsilon参数设为一个很小的值如0.001。

三、TensorFlow代码实现

1. 完整代码

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
 
ACTIVITION = tf.nn.relu 
N_LAYERS = 7 # 总共7层隐藏层 
N_HIDDEN_UNITS = 30 # 每层包含30个神经元 
 
def fix_seed(seed=1): # 设置随机数种子 
  np.random.seed(seed) 
  tf.set_random_seed(seed) 
 
def plot_his(inputs, inputs_norm): # 绘制直方图函数 
  for j, all_inputs in enumerate([inputs, inputs_norm]): 
    for i, input in enumerate(all_inputs): 
      plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1)) 
      plt.cla() 
      if i == 0: 
        the_range = (-7, 10) 
      else: 
        the_range = (-1, 1) 
      plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733') 
      plt.yticks(()) 
      if j == 1: 
        plt.xticks(the_range) 
      else: 
        plt.xticks(()) 
      ax = plt.gca() 
      ax.spines['right'].set_color('none') 
      ax.spines['top'].set_color('none') 
    plt.title("%s normalizing" % ("Without" if j == 0 else "With")) 
  plt.draw() 
  plt.pause(0.01) 
 
def built_net(xs, ys, norm): # 搭建网络函数 
  # 添加层 
  def add_layer(inputs, in_size, out_size, activation_function=None, norm=False): 
    Weights = tf.Variable(tf.random_normal([in_size, out_size], 
                        mean=0.0, stddev=1.0)) 
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) 
    Wx_plus_b = tf.matmul(inputs, Weights) + biases 
 
    if norm: # 判断是否是Batch Normalization层 
      # 计算均值和方差,axes参数0表示batch维度 
      fc_mean, fc_var = tf.nn.moments(Wx_plus_b, axes=[0]) 
      scale = tf.Variable(tf.ones([out_size])) 
      shift = tf.Variable(tf.zeros([out_size])) 
      epsilon = 0.001 
 
      # 定义滑动平均模型对象 
      ema = tf.train.ExponentialMovingAverage(decay=0.5) 
 
      def mean_var_with_update(): 
        ema_apply_op = ema.apply([fc_mean, fc_var]) 
        with tf.control_dependencies([ema_apply_op]): 
          return tf.identity(fc_mean), tf.identity(fc_var) 
 
      mean, var = mean_var_with_update() 
 
      Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, 
                         shift, scale, epsilon) 
 
    if activation_function is None: 
      outputs = Wx_plus_b 
    else: 
      outputs = activation_function(Wx_plus_b) 
    return outputs 
 
  fix_seed(1) 
 
  if norm: # 为第一层进行BN 
    fc_mean, fc_var = tf.nn.moments(xs, axes=[0]) 
    scale = tf.Variable(tf.ones([1])) 
    shift = tf.Variable(tf.zeros([1])) 
    epsilon = 0.001 
 
    ema = tf.train.ExponentialMovingAverage(decay=0.5) 
 
    def mean_var_with_update(): 
      ema_apply_op = ema.apply([fc_mean, fc_var]) 
      with tf.control_dependencies([ema_apply_op]): 
        return tf.identity(fc_mean), tf.identity(fc_var) 
 
    mean, var = mean_var_with_update() 
    xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon) 
 
  layers_inputs = [xs] # 记录每一层的输入 
 
  for l_n in range(N_LAYERS): # 依次添加7层 
    layer_input = layers_inputs[l_n] 
    in_size = layers_inputs[l_n].get_shape()[1].value 
 
    output = add_layer(layer_input, in_size, N_HIDDEN_UNITS, ACTIVITION, norm) 
    layers_inputs.append(output) 
 
  prediction = add_layer(layers_inputs[-1], 30, 1, activation_function=None) 
  cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), 
                    reduction_indices=[1])) 
 
  train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost) 
  return [train_op, cost, layers_inputs] 
 
fix_seed(1) 
x_data = np.linspace(-7, 10, 2500)[:, np.newaxis] 
np.random.shuffle(x_data) 
noise =np.random.normal(0, 8, x_data.shape) 
y_data = np.square(x_data) - 5 + noise 
 
plt.scatter(x_data, y_data) 
plt.show() 
 
xs = tf.placeholder(tf.float32, [None, 1]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
 
train_op, cost, layers_inputs = built_net(xs, ys, norm=False) 
train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True) 
 
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
 
  cost_his = [] 
  cost_his_norm = [] 
  record_step = 5 
 
  plt.ion() 
  plt.figure(figsize=(7, 3)) 
  for i in range(250): 
    if i % 50 == 0: 
      all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], 
                          feed_dict={xs: x_data, ys: y_data}) 
      plot_his(all_inputs, all_inputs_norm) 
 
    sess.run([train_op, train_op_norm], 
         feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]}) 
 
    if i % record_step == 0: 
      cost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data})) 
      cost_his_norm.append(sess.run(cost_norm, 
                     feed_dict={xs: x_data, ys: y_data})) 
 
  plt.ioff() 
  plt.figure() 
  plt.plot(np.arange(len(cost_his))*record_step, 
       np.array(cost_his), label='Without BN')   # no norm 
  plt.plot(np.arange(len(cost_his))*record_step, 
       np.array(cost_his_norm), label='With BN')  # norm 
  plt.legend() 
  plt.show()

2. 实验结果

输入数据分布:

TensorFlow实现Batch Normalization

批标准化BN效果对比:

TensorFlow实现Batch Normalization

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的正则表达式re模块的常用方法
Mar 09 Python
Python遍历目录并批量更换文件名和目录名的方法
Sep 19 Python
python爬虫爬取网页表格数据
Mar 07 Python
python中pylint使用方法(pylint代码检查)
Apr 06 Python
python监测当前联网状态并连接的实例
Dec 18 Python
django2.0扩展用户字段示例
Feb 13 Python
python 利用jinja2模板生成html代码实例
Oct 10 Python
Python 下载及安装详细步骤
Nov 04 Python
Python之——生成动态路由轨迹图的实例
Nov 22 Python
Python JSON编解码方式原理详解
Jan 20 Python
python实现马丁策略回测3000只股票的实例代码
Jan 22 Python
Python Django搭建文件下载服务器的实现
May 10 Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
TensorFlow模型保存/载入的两种方法
Mar 08 #Python
You might like
仿Aspnetpager的一个PHP分页类代码 附源码下载
2012/10/08 PHP
学习PHP Cookie处理函数
2016/08/09 PHP
网上抓的一个特效
2007/05/11 Javascript
javascript 异步页面查询实现代码(asp.net)
2010/05/26 Javascript
jQuery EasyUI API 中文文档 - Pagination分页
2011/09/29 Javascript
JavaScript高级程序设计(第3版)学习笔记7 js函数(上)
2012/10/11 Javascript
Egret引擎开发指南之视觉编程
2014/09/03 Javascript
关于JS中的apply,call,bind的深入解析
2016/04/05 Javascript
AngularJS中的Promise详细介绍及实例代码
2016/12/13 Javascript
JS实现根据密码长度显示安全条功能
2017/03/08 Javascript
详解JavaScript数组过滤相同元素的5种方法
2017/05/23 Javascript
javascript  数组排序与对象排序的实例
2017/07/17 Javascript
AngularJs 延时器、计时器实例代码
2017/09/16 Javascript
JS简单实现父子窗口传值功能示例【未使用iframe框架】
2017/09/20 Javascript
vux uploader 图片上传组件的安装使用方法
2018/05/15 Javascript
JavaScript实现简单的图片切换功能(实例代码)
2020/04/10 Javascript
vue 解决provide和inject响应的问题
2020/11/12 Javascript
Python实现Youku视频批量下载功能
2017/03/14 Python
Python 使用 environs 库定义环境变量的方法
2020/02/25 Python
如何基于python实现年会抽奖工具
2020/10/20 Python
html5实现的便签特效(实战分享)
2013/11/29 HTML / CSS
HTML5 canvas基本绘图之填充样式实现
2016/06/27 HTML / CSS
香港太阳眼镜网上商店:SmartBuyGlasses香港
2016/07/22 全球购物
英国领先的运动物理治疗供应公司:Vivomed
2018/07/14 全球购物
What's the difference between Debug and Trace class? (Debug类与Trace类有什么区别)
2013/09/10 面试题
中央空调节能方案
2014/06/15 职场文书
国际贸易毕业生求职信
2014/07/20 职场文书
办护照工作证明
2014/10/01 职场文书
部门群众路线教育实践活动对照检查材料思想汇报
2014/10/07 职场文书
婚前保证书范文
2015/02/28 职场文书
求职简历自我评价怎么写
2015/03/10 职场文书
领导干部失职检讨书
2015/05/05 职场文书
Python爬虫进阶之Beautiful Soup库详解
2021/04/29 Python
python单元测试之pytest的使用
2021/06/07 Python
python可视化大屏库big_screen示例详解
2021/11/23 Python
Python使用pandas导入xlsx格式的excel文件内容操作代码
2022/12/24 Python