TensorFlow实现批量归一化操作的示例


Posted in Python onApril 22, 2020

批量归一化

在对神经网络的优化方法中,有一种使用十分广泛的方法——批量归一化,使得神经网络的识别准确度得到了极大的提升。

在网络的前向计算过程中,当输出的数据不再同一分布时,可能会使得loss的值非常大,使得网络无法进行计算。产生梯度爆炸的原因是因为网络的内部协变量转移,即正向传播的不同层参数会将反向训练计算时参照的数据样本分布改变。批量归一化的目的,就是要最大限度地保证每次的正向传播输出在同一分布上,这样反向计算时参照的数据样本分布就会与正向计算时的数据分布一样了,保证分布的统一。

了解了原理,批量正则化的做法就会变得简单,即将每一层运算出来的数据都归一化成均值为0方差为1的标准高斯分布。这样就会在保留样本分布特征的同时,又消除层与层间的分布差异。在实际的应用中,批量归一化的收敛非常快,并且有很强的泛化能力,在一些情况下,完全可以代替前面的正则化,dropout。

批量归一化的定义

在TensorFlow中有自带的BN函数定义:

tf.nn.batch_normalization(x,
             maen,
             variance,
             offset,
             scale,
             variance_epsilon)

各个参数的含义如下:

x:代表输入

mean:代表样本的均值

variance:代表方差

offset:代表偏移量,即相加一个转化值,通常是用激活函数来做。

scale:代表缩放,即乘以一个转化值,同理,一般是1

variance_epsilon:为了避免分母是0的情况,给分母加一个极小值。

要使用这个函数,还需要另外的一个函数的配合:tf.nn.moments(),由此函数来计算均值和方差,然后就可以使用BN了,给函数的定义如下:

tf.nn.moments(x, axes, name, keep_dims=False),axes指定那个轴求均值和方差。

为了更好的效果,我们使用平滑指数衰减的方法来优化每次的均值和方差,这里可以使用

tf.train.ExponentialMovingAverage()函数,它的作用是让上一次的值对本次的值有一个衰减后的影响,从而使的每次的值连起来后会相对平滑一下。

批量归一化的简单用法

下面介绍具体的用法,在使用的时候需要引入头文件。

from tensorflow.contrib.layers.python.layers import batch_norm

函数的定义如下:

batch_norm(inputs,
      decay,
      center,
      scale,
      epsilon,
      activation_fn,
      param_initializers=None,
      param_regularizers=None,
      updates_collections=ops.GraphKeys.UPDATE_OPS,
      is_training=True,
      reuse=None,
      variables_collections=None,
      outputs_collections=None,
      trainable=True,
      batch_weights=None,
      fused=False,
      data_format=DATA_FORMAT_NHWC,
      zero_debias_moving_mean=False,
      scope=None,
      renorm=False,
      renorm_clipping=None,
      renorm_decay=0.99)

各参数的具体含义如下:

inputs:输入

decay:移动平均值的衰减速度,使用的是平滑指数衰减的方法更新均值方差,一般会设置0.9,值太小会导致更新太快,值太大会导致几乎没有衰减,容易出现过拟合。

scale:是否进行变换,通过乘以一个gamma值进行缩放,我们常习惯在BN后面接一个线性变化,如relu。

epsilon:为了避免分母为0,给分母加上一个极小值,一般默认。

is_training:当为True时,代表训练过程,这时会不断更新样本集的均值和方差,当测试时,要设置为False,这样就会使用训练样本的均值和方差。

updates_collections:在训练时,提供一种内置的均值方差更新机制,即通过图中的tf.GraphKeys.UPDATE_OPS变量来更新。但它是在每次当前批次训练完成后才更新均值和方差,这样导致当前数据总是使用前一次的均值和方差,没有得到最新的值,所以一般设置为None,让均值和方差及时更新,但在性能上稍慢。

reuse:支持变量共享。

具体的代码如下:

x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
train = tf.Variable(tf.constant(False))

x_images = tf.reshape(x, [-1, 32, 32, 3])


def batch_norm_layer(value, train=False, name='batch_norm'):
  if train is not False:
    return batch_norm(value, decay=0.9, updates_collections=None, is_training=True)
  else:
    return batch_norm(value, decay=0.9, updates_collections=None, is_training=False)


w_conv1 = init_cnn.weight_variable([3, 3, 3, 64]) # [-1, 32, 32, 3]
b_conv1 = init_cnn.bias_variable([64])
h_conv1 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(x_images, w_conv1) + b_conv1), train))
h_pool1 = init_cnn.max_pool_2x2(h_conv1)


w_conv2 = init_cnn.weight_variable([3, 3, 64, 64]) # [-1, 16, 16, 64]
b_conv2 = init_cnn.bias_variable([64])
h_conv2 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool1, w_conv2) + b_conv2), train))
h_pool2 = init_cnn.max_pool_2x2(h_conv2)


w_conv3 = init_cnn.weight_variable([3, 3, 64, 32]) # [-1, 18, 8, 32]
b_conv3 = init_cnn.bias_variable([32])
h_conv3 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool2, w_conv3) + b_conv3), train))
h_pool3 = init_cnn.max_pool_2x2(h_conv3)

w_conv4 = init_cnn.weight_variable([3, 3, 32, 16]) # [-1, 18, 8, 32]
b_conv4 = init_cnn.bias_variable([16])
h_conv4 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool3, w_conv4) + b_conv4), train))
h_pool4 = init_cnn.max_pool_2x2(h_conv4)


w_conv5 = init_cnn.weight_variable([3, 3, 16, 10]) # [-1, 4, 4, 16]
b_conv5 = init_cnn.bias_variable([10])
h_conv5 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool4, w_conv5) + b_conv5), train))
h_pool5 = init_cnn.avg_pool_4x4(h_conv5)         # [-1, 4, 4, 10]

y_pool = tf.reshape(h_pool5, shape=[-1, 10])


cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pool))

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

加上了BN层之后,识别的准确率显著的得到了提升,并且计算速度也是飞起。

到此这篇关于TensorFlow实现批量归一化操作的示例的文章就介绍到这了,更多相关TensorFlow 批量归一化操作内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
以一个投票程序的实例来讲解Python的Django框架使用
Feb 18 Python
Python selenium 三种等待方式解读
Sep 15 Python
Python实现比较扑克牌大小程序代码示例
Dec 06 Python
Python自动化运维之IP地址处理模块详解
Dec 10 Python
PyGame贪吃蛇的实现代码示例
Nov 21 Python
python中aioysql(异步操作MySQL)的方法
Apr 11 Python
Python Des加密解密如何实现软件注册码机器码
Jan 08 Python
使用 pytorch 创建神经网络拟合sin函数的实现
Feb 24 Python
关于python 的legend图例,参数使用说明
Apr 17 Python
从0到1使用python开发一个半自动答题小程序的实现
May 12 Python
在keras里面实现计算f1-score的代码
Jun 15 Python
pytorch中的model.eval()和BN层的使用
May 22 Python
三步解决python PermissionError: [WinError 5]拒绝访问的情况
Apr 22 #Python
python实现四人制扑克牌游戏
Apr 22 #Python
如何在django中实现分页功能
Apr 22 #Python
在Windows上安装和配置 Jupyter Lab 作为桌面级应用程序教程
Apr 22 #Python
python实现扑克牌交互式界面发牌程序
Apr 22 #Python
文件上传服务器-jupyter 中python解压及压缩方式
Apr 22 #Python
如何将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
Apr 22 #Python
You might like
5.PHP的其他功能
2006/10/09 PHP
PHP删除非空目录的函数代码小结
2013/02/28 PHP
php导出excel格式数据问题
2014/03/11 PHP
Codeigniter中集成smarty和adodb的方法
2016/03/04 PHP
用CSS+JS实现的进度条效果效果
2007/06/05 Javascript
JavaScript设置IFrame高度自适应(兼容各主流浏览器)
2013/06/05 Javascript
jquery div拖动效果示例代码
2013/12/08 Javascript
js在输入框屏蔽按键,只能键入数字的示例代码
2014/01/03 Javascript
详解JavaScript中的blink()方法的使用
2015/06/08 Javascript
JS实现带关闭功能的阿里妈妈网站顶部滑出banner工具条代码
2015/09/17 Javascript
JavaScript中循环遍历Array与Map的方法小结
2016/03/12 Javascript
jquery 实现滚动条下拉时无限加载的简单实例
2016/06/01 Javascript
基于JS实现回到页面顶部的五种写法(从实现到增强)
2016/09/03 Javascript
Angular.js ng-file-upload结合springMVC的使用教程
2017/07/10 Javascript
vue利用axios来完成数据的交互
2018/03/23 Javascript
微信小程序常用赋值方法小结
2019/04/30 Javascript
小程序云开发教程如何使用云函数实现点赞功能
2019/05/18 Javascript
解决python xlrd无法读取excel文件的问题
2018/12/25 Python
Python数据类型之String字符串实例详解
2019/05/08 Python
Python实现中值滤波去噪方式
2019/12/18 Python
使用TensorFlow对图像进行随机旋转的实现示例
2020/01/20 Python
Django-xadmin+rule对象级权限的实现方式
2020/03/30 Python
python 获取字典特定值对应的键的实现
2020/09/29 Python
15个Pythonic的代码示例(值得收藏)
2020/10/29 Python
python IP地址转整数
2020/11/20 Python
国际书籍零售商:Wordery
2017/11/01 全球购物
Kate Spade美国官网:纽约新兴时尚品牌,以包包闻名于世
2017/11/09 全球购物
英国和世界各地预订便宜的酒店:LateRooms.com
2019/05/05 全球购物
日语系毕业生推荐信
2013/11/11 职场文书
计算机学生的自我评价分享
2014/02/18 职场文书
《跨越海峡的生命桥》教学反思
2014/02/24 职场文书
丧事主持词大全
2014/04/02 职场文书
2015年建党94周年演讲稿
2015/03/19 职场文书
2015年度信用社工作总结
2015/05/04 职场文书
mysql timestamp比较查询遇到的坑及解决
2021/11/27 MySQL
Python可视化神器pyecharts绘制水球图
2022/07/07 Python