Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
go和python调用其它程序并得到程序输出
Feb 10 Python
Python中isnumeric()方法的使用简介
May 19 Python
python针对excel的操作技巧
Mar 13 Python
python实现人人自动回复、抢沙发功能
Jun 08 Python
mac安装scrapy并创建项目的实例讲解
Jun 13 Python
python高级特性和高阶函数及使用详解
Oct 17 Python
使用python代码进行身份证号校验的实现示例
Nov 21 Python
Python通过VGG16模型实现图像风格转换操作详解
Jan 16 Python
六种酷炫Python运行进度条效果的实现代码
Jul 17 Python
Python爬虫之Selenium警告框(弹窗)处理
Dec 04 Python
详解python中的三种命令行模块(sys.argv,argparse,click)
Dec 15 Python
Python绘制K线图之可视化神器pyecharts的使用
Mar 02 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
php中$_REQUEST、$_POST、$_GET的区别和联系小结
2011/11/23 PHP
smarty自定义函数htmlcheckboxes用法实例
2015/01/22 PHP
thinkPHP删除前弹出确认框的简单实现方法
2016/05/16 PHP
PHP实现冒泡排序的简单实例
2016/05/26 PHP
thinkPHP模板中for循环与switch语句用法示例
2016/11/30 PHP
由JavaScript中call()方法引发的对面向对象继承机制call的思考
2011/09/12 Javascript
jquery实现按Enter键触发事件示例
2013/09/10 Javascript
js字母大小写转换实现方法总结
2013/11/13 Javascript
Vue组件BootPage实现简单的分页功能
2016/09/12 Javascript
AngularJS 单元测试(二)详解
2016/09/21 Javascript
JQuery Ajax WebService传递参数的简单实例
2016/11/02 Javascript
通过AngularJS实现图片上传及缩略图展示示例
2017/01/03 Javascript
基于jQuery实现一个marquee无缝滚动的插件
2017/03/09 Javascript
jQuery实现移动端Tab选项卡效果
2017/03/15 Javascript
Zepto实现密码的隐藏/显示
2017/04/07 Javascript
vue实现todolist单页面应用
2017/04/11 Javascript
详解AngularJS脏检查机制及$timeout的妙用
2017/06/19 Javascript
js 监控iframe URL的变化实例代码
2017/07/12 Javascript
[01:14]2014DOTA2展望TI 剑指西雅图newbee战队专访
2014/06/30 DOTA
[03:44]2015国际邀请赛选手档案—Cloud9.NoTail
2015/07/28 DOTA
[52:44]VGJ.T vs infamous Supermajor小组赛D组败者组第一轮 BO3 第一场 6.3
2018/06/04 DOTA
Python 中urls.py:URL dispatcher(路由配置文件)详解
2017/03/24 Python
Python实现购物车购物小程序
2018/04/18 Python
详解Python的循环结构知识点
2019/05/20 Python
连接pandas以及数组转pandas的方法
2019/06/28 Python
Python实现性能自动化测试竟然如此简单
2019/07/30 Python
Python简单实现词云图代码及步骤解析
2020/06/04 Python
美国半成品食材配送服务商:Home Chef
2018/01/25 全球购物
来自Ocado的宠物商店:Fetch
2018/07/10 全球购物
全球烹饪课程的领先预订平台:Cookly
2020/01/28 全球购物
请描述一下”is a”关系和”has a”关系
2015/02/03 面试题
若通过ObjectOutputStream向一个文件中多次以追加方式写入object,为什么用ObjectInputStream读取这些object时会产生StreamCorruptedException?
2016/10/17 面试题
便利店促销方案
2014/02/20 职场文书
国际贸易本科毕业生求职信
2014/09/26 职场文书
2016年学校爱国卫生月活动总结
2016/04/06 职场文书
java实现自定义时钟并实现走时功能
2022/06/21 Java/Android