TensorFlow实现批量归一化操作的示例


Posted in Python onApril 22, 2020

批量归一化

在对神经网络的优化方法中,有一种使用十分广泛的方法——批量归一化,使得神经网络的识别准确度得到了极大的提升。

在网络的前向计算过程中,当输出的数据不再同一分布时,可能会使得loss的值非常大,使得网络无法进行计算。产生梯度爆炸的原因是因为网络的内部协变量转移,即正向传播的不同层参数会将反向训练计算时参照的数据样本分布改变。批量归一化的目的,就是要最大限度地保证每次的正向传播输出在同一分布上,这样反向计算时参照的数据样本分布就会与正向计算时的数据分布一样了,保证分布的统一。

了解了原理,批量正则化的做法就会变得简单,即将每一层运算出来的数据都归一化成均值为0方差为1的标准高斯分布。这样就会在保留样本分布特征的同时,又消除层与层间的分布差异。在实际的应用中,批量归一化的收敛非常快,并且有很强的泛化能力,在一些情况下,完全可以代替前面的正则化,dropout。

批量归一化的定义

在TensorFlow中有自带的BN函数定义:

tf.nn.batch_normalization(x,
             maen,
             variance,
             offset,
             scale,
             variance_epsilon)

各个参数的含义如下:

x:代表输入

mean:代表样本的均值

variance:代表方差

offset:代表偏移量,即相加一个转化值,通常是用激活函数来做。

scale:代表缩放,即乘以一个转化值,同理,一般是1

variance_epsilon:为了避免分母是0的情况,给分母加一个极小值。

要使用这个函数,还需要另外的一个函数的配合:tf.nn.moments(),由此函数来计算均值和方差,然后就可以使用BN了,给函数的定义如下:

tf.nn.moments(x, axes, name, keep_dims=False),axes指定那个轴求均值和方差。

为了更好的效果,我们使用平滑指数衰减的方法来优化每次的均值和方差,这里可以使用

tf.train.ExponentialMovingAverage()函数,它的作用是让上一次的值对本次的值有一个衰减后的影响,从而使的每次的值连起来后会相对平滑一下。

批量归一化的简单用法

下面介绍具体的用法,在使用的时候需要引入头文件。

from tensorflow.contrib.layers.python.layers import batch_norm

函数的定义如下:

batch_norm(inputs,
      decay,
      center,
      scale,
      epsilon,
      activation_fn,
      param_initializers=None,
      param_regularizers=None,
      updates_collections=ops.GraphKeys.UPDATE_OPS,
      is_training=True,
      reuse=None,
      variables_collections=None,
      outputs_collections=None,
      trainable=True,
      batch_weights=None,
      fused=False,
      data_format=DATA_FORMAT_NHWC,
      zero_debias_moving_mean=False,
      scope=None,
      renorm=False,
      renorm_clipping=None,
      renorm_decay=0.99)

各参数的具体含义如下:

inputs:输入

decay:移动平均值的衰减速度,使用的是平滑指数衰减的方法更新均值方差,一般会设置0.9,值太小会导致更新太快,值太大会导致几乎没有衰减,容易出现过拟合。

scale:是否进行变换,通过乘以一个gamma值进行缩放,我们常习惯在BN后面接一个线性变化,如relu。

epsilon:为了避免分母为0,给分母加上一个极小值,一般默认。

is_training:当为True时,代表训练过程,这时会不断更新样本集的均值和方差,当测试时,要设置为False,这样就会使用训练样本的均值和方差。

updates_collections:在训练时,提供一种内置的均值方差更新机制,即通过图中的tf.GraphKeys.UPDATE_OPS变量来更新。但它是在每次当前批次训练完成后才更新均值和方差,这样导致当前数据总是使用前一次的均值和方差,没有得到最新的值,所以一般设置为None,让均值和方差及时更新,但在性能上稍慢。

reuse:支持变量共享。

具体的代码如下:

x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
train = tf.Variable(tf.constant(False))

x_images = tf.reshape(x, [-1, 32, 32, 3])


def batch_norm_layer(value, train=False, name='batch_norm'):
  if train is not False:
    return batch_norm(value, decay=0.9, updates_collections=None, is_training=True)
  else:
    return batch_norm(value, decay=0.9, updates_collections=None, is_training=False)


w_conv1 = init_cnn.weight_variable([3, 3, 3, 64]) # [-1, 32, 32, 3]
b_conv1 = init_cnn.bias_variable([64])
h_conv1 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(x_images, w_conv1) + b_conv1), train))
h_pool1 = init_cnn.max_pool_2x2(h_conv1)


w_conv2 = init_cnn.weight_variable([3, 3, 64, 64]) # [-1, 16, 16, 64]
b_conv2 = init_cnn.bias_variable([64])
h_conv2 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool1, w_conv2) + b_conv2), train))
h_pool2 = init_cnn.max_pool_2x2(h_conv2)


w_conv3 = init_cnn.weight_variable([3, 3, 64, 32]) # [-1, 18, 8, 32]
b_conv3 = init_cnn.bias_variable([32])
h_conv3 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool2, w_conv3) + b_conv3), train))
h_pool3 = init_cnn.max_pool_2x2(h_conv3)

w_conv4 = init_cnn.weight_variable([3, 3, 32, 16]) # [-1, 18, 8, 32]
b_conv4 = init_cnn.bias_variable([16])
h_conv4 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool3, w_conv4) + b_conv4), train))
h_pool4 = init_cnn.max_pool_2x2(h_conv4)


w_conv5 = init_cnn.weight_variable([3, 3, 16, 10]) # [-1, 4, 4, 16]
b_conv5 = init_cnn.bias_variable([10])
h_conv5 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool4, w_conv5) + b_conv5), train))
h_pool5 = init_cnn.avg_pool_4x4(h_conv5)         # [-1, 4, 4, 10]

y_pool = tf.reshape(h_pool5, shape=[-1, 10])


cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pool))

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

加上了BN层之后,识别的准确率显著的得到了提升,并且计算速度也是飞起。

到此这篇关于TensorFlow实现批量归一化操作的示例的文章就介绍到这了,更多相关TensorFlow 批量归一化操作内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
详解Python中的序列化与反序列化的使用
Jun 30 Python
python实现发送和获取手机短信验证码
Jan 15 Python
python+opencv识别图片中的圆形
Mar 25 Python
Python实现读取SQLServer数据并插入到MongoDB数据库的方法示例
Jun 09 Python
pandas将numpy数组写入到csv的实例
Jul 04 Python
django 邮件发送模块smtp使用详解
Jul 22 Python
基于Python解密仿射密码
Oct 21 Python
详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程
Nov 02 Python
如何在scrapy中集成selenium爬取网页的方法
Nov 18 Python
python实现学生信息管理系统(精简版)
Nov 27 Python
python爬虫scrapy框架之增量式爬虫的示例代码
Feb 26 Python
浅谈pytorch中的dropout的概率p
May 27 Python
三步解决python PermissionError: [WinError 5]拒绝访问的情况
Apr 22 #Python
python实现四人制扑克牌游戏
Apr 22 #Python
如何在django中实现分页功能
Apr 22 #Python
在Windows上安装和配置 Jupyter Lab 作为桌面级应用程序教程
Apr 22 #Python
python实现扑克牌交互式界面发牌程序
Apr 22 #Python
文件上传服务器-jupyter 中python解压及压缩方式
Apr 22 #Python
如何将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
Apr 22 #Python
You might like
PHP实现批量删除(封装)
2017/04/28 PHP
JS获取当前网址、主机地址项目根路径
2013/11/19 Javascript
JS 打印功能代码可实现打印预览、打印设置等
2014/10/31 Javascript
jquery+php实现搜索框自动提示
2014/11/28 Javascript
AngularJs中route的使用方法和配置
2016/02/04 Javascript
jQuery自定义多选下拉框效果
2017/06/19 jQuery
JavaScript之事件委托实例(附原生js和jQuery代码)
2017/07/22 jQuery
JavaScript实现简单的双色球(实例讲解)
2017/07/31 Javascript
详解require.js配置路径的用法和css的引入
2017/09/06 Javascript
import与export在node.js中的使用详解
2017/09/28 Javascript
jQuery实现监听下拉框选中内容发生改变操作示例
2018/07/13 jQuery
Vue面试题及Vue知识点整理
2018/10/07 Javascript
vue-cli3全面配置详解
2018/11/14 Javascript
如何优雅地取消 JavaScript 异步任务
2020/03/22 Javascript
javascript实现文字跑马灯效果
2020/06/18 Javascript
[01:16:01]VGJ.S vs Mski Supermajor小组赛C组 BO3 第一场 6.3
2018/06/04 DOTA
使用python获取CPU和内存信息的思路与实现(linux系统)
2014/01/03 Python
Python异常学习笔记
2015/02/03 Python
在Mac OS上部署Nginx和FastCGI以及Flask框架的教程
2015/05/02 Python
Python中的ceil()方法使用教程
2015/05/14 Python
Python运算符重载用法实例
2015/05/28 Python
python转换字符串为摩尔斯电码的方法
2015/07/06 Python
python字符串的方法与操作大全
2018/01/30 Python
3个用于数据科学的顶级Python库
2018/09/29 Python
在python下使用tensorflow判断是否存在文件夹的实例
2019/06/10 Python
python中的单引号双引号区别知识点总结
2019/06/23 Python
python 实现识别图片上的数字
2019/07/30 Python
Python 3.6打包成EXE可执行程序的实现
2019/10/18 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
2020/01/02 Python
Linux下升级安装python3.8并配置pip及yum的教程
2020/01/02 Python
Python使用Turtle模块绘制国旗的方法示例
2021/02/28 Python
北美个性化礼品商店:Things Remembered
2018/06/12 全球购物
单位刻章介绍信范文
2014/01/11 职场文书
便利店投资的创业计划书
2014/01/12 职场文书
家长对小学生的评语
2014/01/28 职场文书
2014年医院个人工作总结
2014/12/09 职场文书