TensorFlow实现Batch Normalization


Posted in Python onMarch 08, 2018

一、BN(Batch Normalization)算法

1. 对数据进行归一化处理的重要性

神经网络学习过程的本质就是学习数据分布,在训练数据与测试数据分布不同情况下,模型的泛化能力就大大降低;另一方面,若训练过程中每批batch的数据分布也各不相同,那么网络每批迭代学习过程也会出现较大波动,使之更难趋于收敛,降低训练收敛速度。对于深层网络,网络前几层的微小变化都会被网络累积放大,则训练数据的分布变化问题会被放大,更加影响训练速度。

2. BN算法的强大之处

1)为了加速梯度下降算法的训练,我们可以采取指数衰减学习率等方法在初期快速学习,后期缓慢进入全局最优区域。使用BN算法后,就可以直接选择比较大的学习率,且设置很大的学习率衰减速度,大大提高训练速度。即使选择了较小的学习率,也会比以前不使用BN情况下的收敛速度快。总结就是BN算法具有快速收敛的特性。

2)BN具有提高网络泛化能力的特性。采用BN算法后,就可以移除针对过拟合问题而设置的dropout和L2正则化项,或者采用更小的L2正则化参数。

3)BN本身是一个归一化网络层,则局部响应归一化层(Local Response Normalization,LRN层)则可不需要了(Alexnet网络中使用到)。

3. BN算法概述

BN算法提出了变换重构,引入了可学习参数γ、β,这就是算法的关键之处:

TensorFlow实现Batch Normalization

引入这两个参数后,我们的网络便可以学习恢复出原是网络所要学习的特征分布,BN层的钱箱传到过程如下:

TensorFlow实现Batch Normalization

其中m为batchsize。BatchNormalization中所有的操作都是平滑可导,这使得back propagation可以有效运行并学到相应的参数γ,β。需要注意的一点是Batch Normalization在training和testing时行为有所差别。Training时μβ和σβ由当前batch计算得出;在Testing时μβ和σβ应使用Training时保存的均值或类似的经过处理的值,而不是由当前batch计算。

二、TensorFlow相关函数

1.tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)

x是输入张量,axes是在哪个维度上求解, 即想要 normalize的维度, [0] 代表 batch 维度,如果是图像数据,可以传入 [0, 1, 2],相当于求[batch, height, width] 的均值/方差,注意不要加入channel 维度。该函数返回两个张量,均值mean和方差variance。

2.tf.identity(input, name=None)

返回与输入张量input形状和内容一致的张量。

3.tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon,name=None)

计算公式为scale(x - mean)/ variance + offset。

这些参数中,tf.nn.moments可得到均值mean和方差variance,offset和scale是可训练的,offset一般初始化为0,scale初始化为1,offset和scale的shape与mean相同,variance_epsilon参数设为一个很小的值如0.001。

三、TensorFlow代码实现

1. 完整代码

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
 
ACTIVITION = tf.nn.relu 
N_LAYERS = 7 # 总共7层隐藏层 
N_HIDDEN_UNITS = 30 # 每层包含30个神经元 
 
def fix_seed(seed=1): # 设置随机数种子 
  np.random.seed(seed) 
  tf.set_random_seed(seed) 
 
def plot_his(inputs, inputs_norm): # 绘制直方图函数 
  for j, all_inputs in enumerate([inputs, inputs_norm]): 
    for i, input in enumerate(all_inputs): 
      plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1)) 
      plt.cla() 
      if i == 0: 
        the_range = (-7, 10) 
      else: 
        the_range = (-1, 1) 
      plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733') 
      plt.yticks(()) 
      if j == 1: 
        plt.xticks(the_range) 
      else: 
        plt.xticks(()) 
      ax = plt.gca() 
      ax.spines['right'].set_color('none') 
      ax.spines['top'].set_color('none') 
    plt.title("%s normalizing" % ("Without" if j == 0 else "With")) 
  plt.draw() 
  plt.pause(0.01) 
 
def built_net(xs, ys, norm): # 搭建网络函数 
  # 添加层 
  def add_layer(inputs, in_size, out_size, activation_function=None, norm=False): 
    Weights = tf.Variable(tf.random_normal([in_size, out_size], 
                        mean=0.0, stddev=1.0)) 
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) 
    Wx_plus_b = tf.matmul(inputs, Weights) + biases 
 
    if norm: # 判断是否是Batch Normalization层 
      # 计算均值和方差,axes参数0表示batch维度 
      fc_mean, fc_var = tf.nn.moments(Wx_plus_b, axes=[0]) 
      scale = tf.Variable(tf.ones([out_size])) 
      shift = tf.Variable(tf.zeros([out_size])) 
      epsilon = 0.001 
 
      # 定义滑动平均模型对象 
      ema = tf.train.ExponentialMovingAverage(decay=0.5) 
 
      def mean_var_with_update(): 
        ema_apply_op = ema.apply([fc_mean, fc_var]) 
        with tf.control_dependencies([ema_apply_op]): 
          return tf.identity(fc_mean), tf.identity(fc_var) 
 
      mean, var = mean_var_with_update() 
 
      Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, 
                         shift, scale, epsilon) 
 
    if activation_function is None: 
      outputs = Wx_plus_b 
    else: 
      outputs = activation_function(Wx_plus_b) 
    return outputs 
 
  fix_seed(1) 
 
  if norm: # 为第一层进行BN 
    fc_mean, fc_var = tf.nn.moments(xs, axes=[0]) 
    scale = tf.Variable(tf.ones([1])) 
    shift = tf.Variable(tf.zeros([1])) 
    epsilon = 0.001 
 
    ema = tf.train.ExponentialMovingAverage(decay=0.5) 
 
    def mean_var_with_update(): 
      ema_apply_op = ema.apply([fc_mean, fc_var]) 
      with tf.control_dependencies([ema_apply_op]): 
        return tf.identity(fc_mean), tf.identity(fc_var) 
 
    mean, var = mean_var_with_update() 
    xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon) 
 
  layers_inputs = [xs] # 记录每一层的输入 
 
  for l_n in range(N_LAYERS): # 依次添加7层 
    layer_input = layers_inputs[l_n] 
    in_size = layers_inputs[l_n].get_shape()[1].value 
 
    output = add_layer(layer_input, in_size, N_HIDDEN_UNITS, ACTIVITION, norm) 
    layers_inputs.append(output) 
 
  prediction = add_layer(layers_inputs[-1], 30, 1, activation_function=None) 
  cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), 
                    reduction_indices=[1])) 
 
  train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost) 
  return [train_op, cost, layers_inputs] 
 
fix_seed(1) 
x_data = np.linspace(-7, 10, 2500)[:, np.newaxis] 
np.random.shuffle(x_data) 
noise =np.random.normal(0, 8, x_data.shape) 
y_data = np.square(x_data) - 5 + noise 
 
plt.scatter(x_data, y_data) 
plt.show() 
 
xs = tf.placeholder(tf.float32, [None, 1]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
 
train_op, cost, layers_inputs = built_net(xs, ys, norm=False) 
train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True) 
 
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
 
  cost_his = [] 
  cost_his_norm = [] 
  record_step = 5 
 
  plt.ion() 
  plt.figure(figsize=(7, 3)) 
  for i in range(250): 
    if i % 50 == 0: 
      all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], 
                          feed_dict={xs: x_data, ys: y_data}) 
      plot_his(all_inputs, all_inputs_norm) 
 
    sess.run([train_op, train_op_norm], 
         feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]}) 
 
    if i % record_step == 0: 
      cost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data})) 
      cost_his_norm.append(sess.run(cost_norm, 
                     feed_dict={xs: x_data, ys: y_data})) 
 
  plt.ioff() 
  plt.figure() 
  plt.plot(np.arange(len(cost_his))*record_step, 
       np.array(cost_his), label='Without BN')   # no norm 
  plt.plot(np.arange(len(cost_his))*record_step, 
       np.array(cost_his_norm), label='With BN')  # norm 
  plt.legend() 
  plt.show()

2. 实验结果

输入数据分布:

TensorFlow实现Batch Normalization

批标准化BN效果对比:

TensorFlow实现Batch Normalization

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python装饰器在Django框架下去除冗余代码的教程
Apr 16 Python
python实现合并两个数组的方法
May 16 Python
Python之py2exe打包工具详解
Jun 14 Python
微信跳一跳python辅助脚本(总结)
Jan 11 Python
Python实现的计算马氏距离算法示例
Apr 03 Python
详解PyTorch手写数字识别(MNIST数据集)
Aug 16 Python
opencv转换颜色空间更改图片背景
Aug 20 Python
Python中BeautifuSoup库的用法使用详解
Nov 15 Python
Python Pandas 转换unix时间戳方式
Dec 07 Python
python主线程与子线程的结束顺序实例解析
Dec 17 Python
解决Windows下python和pip命令无法使用的问题
Aug 31 Python
看看如何用Python绘制小米新版天价logo
Apr 20 Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
TensorFlow模型保存/载入的两种方法
Mar 08 #Python
You might like
php中将数组转成字符串并保存到数据库中的函数代码
2013/09/29 PHP
PHP防止表单重复提交的几种常用方法汇总
2014/08/19 PHP
PHP实现的常规正则验证helper公共类完整实例
2017/04/27 PHP
php基于Redis消息队列实现的消息推送的方法
2018/11/28 PHP
js的一些常用方法小结
2011/06/29 Javascript
用javascript读取xml文件读取节点数据
2014/08/12 Javascript
jquery用offset()方法获得元素的xy坐标
2014/09/06 Javascript
node.js中的fs.fchmodSync方法使用说明
2014/12/16 Javascript
jQuery实现自动切换播放的经典滑动门效果
2015/09/12 Javascript
jQuery超简单选项卡完整实例
2015/09/26 Javascript
AngularJS进行性能调优的7个建议
2015/12/28 Javascript
基于JQuery实现分隔条的功能
2016/06/17 Javascript
jQuery绑定自定义事件的魔法升级版
2016/06/30 Javascript
Node.js Streams文件读写操作详解
2016/07/04 Javascript
VC调用javascript的几种方法(推荐)
2016/08/09 Javascript
jQuery 全选 全不选 事件绑定的实现代码
2017/01/23 Javascript
JavaScript无阻塞加载和defer、async详解
2017/02/26 Javascript
parabola.js抛物线与加入购物车效果的示例代码
2017/10/25 Javascript
vue实现消息的无缝滚动效果的示例代码
2017/12/05 Javascript
浅析Vue中method与computed的区别
2018/03/06 Javascript
element中table高度自适应的实现
2020/10/21 Javascript
[05:09]第二届DOTA2亚洲邀请赛决赛日比赛集锦:iG 3:0 OG夺冠
2017/04/05 DOTA
Python浅拷贝与深拷贝用法实例
2015/05/09 Python
python在ubuntu中的几种安装方法(小结)
2017/12/08 Python
Python数据结构与算法之图的基本实现及迭代器实例详解
2017/12/12 Python
树莓派+摄像头实现对移动物体的检测
2019/06/22 Python
python项目对接钉钉SDK的实现
2019/07/15 Python
新年福利来一波之Python轻松集齐五福(demo)
2020/01/20 Python
Python如何实现远程方法调用
2020/08/07 Python
html5实现多文件的上传示例代码
2014/02/13 HTML / CSS
HTML5 DeviceOrientation实现手机网站摇一摇功能代码实例
2015/04/24 HTML / CSS
Expedia英国:全球最大的在线旅游公司
2017/09/07 全球购物
汉森批发:Hansen Wholesale
2018/05/24 全球购物
网络工程专业毕业生推荐信
2013/10/28 职场文书
九年级英语教学反思
2014/01/31 职场文书
新课程改革心得体会
2016/01/22 职场文书