Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python通过socket查询whois的方法
Jul 18 Python
简单讲解Python中的数字类型及基本的数学计算
Mar 11 Python
python 随机数使用方法,推导以及字符串,双色球小程序实例
Sep 12 Python
python OpenCV学习笔记直方图反向投影的实现
Feb 07 Python
Python 变量类型详解
Oct 10 Python
为什么str(float)在Python 3中比Python 2返回更多的数字
Oct 16 Python
python实现求特征选择的信息增益
Dec 18 Python
Python3.5 Json与pickle实现数据序列化与反序列化操作示例
Apr 29 Python
python爬虫selenium和phantomJs使用方法解析
Aug 08 Python
Python 动态导入对象,importlib.import_module()的使用方法
Aug 28 Python
Pandas时间序列:时期(period)及其算术运算详解
Feb 25 Python
python模拟点击网页按钮实现方法
Feb 25 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
PHP将回调函数作用到给定数组单元的方法
2014/08/19 PHP
php使用递归计算文件夹大小
2014/12/24 PHP
PHP实现加密的几种方式介绍
2015/02/22 PHP
从wamp到xampp的升级之路
2015/04/08 PHP
PHP实现批量修改文件后缀名的方法
2015/07/30 PHP
既简单又安全的PHP验证码 附调用方法
2016/06/02 PHP
JavaScript 异步调用框架 (Part 6 - 实例 & 模式)
2009/08/04 Javascript
js实现遮罩层划出效果是生成div而不是显示
2014/07/29 Javascript
jQuery使用$.ajax进行即时验证的方法
2015/12/08 Javascript
Avalon中文长字符截取、关键字符隐藏、自定义过滤器
2016/05/18 Javascript
Angular2 (RC5) 路由与导航详解
2016/09/21 Javascript
jQuery中Datatables增加跳转到指定页功能
2017/02/08 Javascript
Javascript实现登录记住用户名和密码功能
2017/03/22 Javascript
node.js 利用流实现读写同步,边读边写的方法
2017/09/11 Javascript
jquery 一键复制到剪切板的实例
2017/09/20 jQuery
基于vue 动态加载图片src的解决方法
2018/02/05 Javascript
Node.js使用MySQL连接池的方法实例
2018/02/11 Javascript
vue keep-alive请求数据的方法示例
2018/05/16 Javascript
vue+axios+mock.js环境搭建的方法步骤
2018/08/28 Javascript
JavaScript继承的特性与实践应用深入详解
2018/12/30 Javascript
微信小程序textarea层级过高(盖住其他元素)问题的解决办法
2019/03/04 Javascript
原生JS实现天气预报
2020/06/16 Javascript
Python中的作用域规则详解
2015/01/30 Python
python计算圆周率pi的方法
2015/07/11 Python
Linux下python与C++使用dlib实现人脸检测
2018/06/29 Python
浅谈python标准库--functools.partial
2019/03/13 Python
html5使用window.postMessage进行跨域实现数据交互的一次实战
2021/02/24 HTML / CSS
英国最大的婴儿监视器网上商店:Baby Monitors Direct
2018/04/24 全球购物
eBay荷兰购物网站:eBay.nl
2020/06/26 全球购物
预备党员思想汇报范文
2014/01/11 职场文书
八项规定整改措施
2014/02/12 职场文书
三好学生演讲稿范文
2014/04/26 职场文书
银行竞聘上岗演讲稿
2014/09/12 职场文书
个人党性锻炼总结
2015/03/05 职场文书
刑事申诉状范文
2015/05/20 职场文书
Mysql MVCC机制原理详解
2021/04/20 MySQL