Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
举例讲解Python面向对象编程中类的继承
Jun 17 Python
深入探究Django中的Session与Cookie
Jul 30 Python
Python实现自动发送邮件功能
Mar 02 Python
Python下载网络小说实例代码
Feb 03 Python
Python中常用的内置方法
Jan 28 Python
Python基于Opencv来快速实现人脸识别过程详解(完整版)
Jul 11 Python
python实现爬虫抓取小说功能示例【抓取金庸小说】
Aug 09 Python
在flask中使用python-dotenv+flask-cli自定义命令(推荐)
Jan 05 Python
pytorch 彩色图像转灰度图像实例
Jan 13 Python
django admin 根据choice字段选择的不同来显示不同的页面方式
May 13 Python
DRF框架API版本管理实现方法解析
Aug 21 Python
Pyecharts 中Geo函数常用参数的用法说明
Feb 01 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
php生成N个不重复的随机数实例
2013/11/12 PHP
PHP上传文件时文件过大$_FILES为空的解决方法
2013/11/26 PHP
php中array_multisort对多维数组排序的方法
2020/06/21 PHP
php需登录的文件上传管理系统
2020/03/21 PHP
php图形jpgraph操作实例分析
2017/02/22 PHP
php安装dblib扩展,连接mssql的具体步骤
2017/03/02 PHP
PHP实现表单提交数据的验证处理功能【防SQL注入和XSS攻击等】
2017/07/21 PHP
可以文本显示的公告栏的js代码
2007/03/11 Javascript
JavaScript 对象成员的可见性说明
2009/10/16 Javascript
JQuery的Alert消息框插件使用介绍
2010/10/09 Javascript
js call方法详细介绍(js 的继承)
2013/11/18 Javascript
node.js中RPC(远程过程调用)的实现原理介绍
2014/12/05 Javascript
利用AJAX实现WordPress中的文章列表及评论的分页功能
2016/05/17 Javascript
BootStrap selectpicker后台动态绑定数据的方法
2017/07/28 Javascript
express如何使用session与cookie的方法
2018/01/30 Javascript
微信小程序使用Promise简化回调
2018/02/06 Javascript
node.js遍历目录的方法示例
2018/08/01 Javascript
angularJS1 url中携带参数的获取方法
2018/10/09 Javascript
在layui中使用form表单监听ajax异步验证注册的实例
2019/09/03 Javascript
JS实现瀑布流效果
2020/03/07 Javascript
python批量添加zabbix Screens的两个脚本分享
2017/01/16 Python
解决python3爬虫无法显示中文的问题
2018/04/12 Python
5款Python程序员高频使用开发工具推荐
2019/04/10 Python
Python+OpenCV采集本地摄像头的视频
2019/04/25 Python
Python程序暂停的正常处理方法
2019/11/07 Python
Python序列化pickle模块使用详解
2020/03/05 Python
django 取消csrf限制的实例
2020/03/13 Python
你需要学会的8个Python列表技巧
2020/06/24 Python
css3图片边框border-image的用法
2017/06/30 HTML / CSS
美国受信赖的教育产品供应商:Nest Learning
2018/06/14 全球购物
iPad和Surface Pro蓝牙键盘:Brydge
2018/11/10 全球购物
加拿大鞋网:Globo Shoes
2019/12/26 全球购物
团日活动策划书
2014/02/01 职场文书
2014工程部年度工作总结
2014/12/17 职场文书
2015人事行政工作总结范文
2015/05/21 职场文书
MySQL8.0.24版本Release Note的一些改进点
2021/04/22 MySQL