Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python发送arp欺骗攻击代码分析
Jan 16 Python
Python 列表(List)操作方法详解
Mar 11 Python
零基础写python爬虫之抓取糗事百科代码分享
Nov 06 Python
python实现文件快照加密保护的方法
Jun 30 Python
Python读取图片属性信息的实现方法
Sep 11 Python
在Linux命令行终端中使用python的简单方法(推荐)
Jan 23 Python
Python书单 不将就
Jul 11 Python
python中使用PIL制作并验证图片验证码
Mar 15 Python
python xlsxwriter库生成图表的应用示例
Mar 16 Python
python实现文件助手中查看微信撤回消息
Apr 29 Python
基于python进行抽样分布描述及实践详解
Sep 02 Python
Python3将ipa包中的文件按大小排序
Apr 17 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
PHP注释实例技巧
2008/10/03 PHP
PHP 面向对象 PHP5 中的常量
2010/05/05 PHP
11个PHPer必须要了解的编程规范
2014/09/22 PHP
php 命名空间(namespace)原理与用法实例小结
2019/11/13 PHP
解决windows上php xdebug 无法调试的问题
2020/02/19 PHP
JS option location 页面跳转实现代码
2008/12/27 Javascript
Exitjs获取DataView中图片文件名
2009/11/26 Javascript
Visual Studio中的jQuery智能提示设置方法
2010/03/27 Javascript
JSON 教程 json入门学习笔记
2020/09/22 Javascript
JS中图片缓冲loading技术的实例代码
2013/08/29 Javascript
javascript弹出页面回传值的方法
2015/01/28 Javascript
深入学习JavaScript中的Rest参数和参数默认值
2015/07/28 Javascript
js实现简洁的滑动门菜单(选项卡)效果代码
2015/09/04 Javascript
JavaScript是如何实现继承的(六种方式)
2016/03/31 Javascript
AngularJs Dependency Injection(DI,依赖注入)
2016/09/02 Javascript
jQuery下拉菜单的实现代码
2016/11/03 Javascript
利用js获取下拉框中所选的值
2016/12/01 Javascript
nodejs实现发出蜂鸣声音(系统报警声)的方法
2017/01/18 NodeJs
jQuery中Datatables增加跳转到指定页功能
2017/02/08 Javascript
jQuery插件HighCharts绘制2D带有Legend的饼图效果示例【附demo源码下载】
2017/03/10 Javascript
require.js与bootstrap结合实现简单的页面登录和页面跳转功能
2017/05/12 Javascript
浅谈vue.js中v-for循环渲染
2017/07/26 Javascript
总结js中的一些兼容性易错的问题
2017/12/18 Javascript
JS二级菜单不同实现方法分析【4种方法】
2018/12/21 Javascript
超轻量级的js时间库miment使用解析
2019/08/02 Javascript
浅谈vue项目用到的mock数据接口的两种方式
2019/10/09 Javascript
[22:07]DOTA2-DPC中国联赛 正赛 iG vs Magma 选手采访
2021/03/11 DOTA
python生成随机验证码(中文验证码)示例
2014/04/03 Python
python 时间信息“2018-02-04 18:23:35“ 解析成字典形式的结果代码详解
2018/04/19 Python
pygame游戏之旅 载入小车图片、更新窗口
2018/11/20 Python
css3 实现圆形旋转倒计时
2018/02/24 HTML / CSS
HTML利用九宫格原理进行网页布局
2020/03/13 HTML / CSS
GAP阿联酋官网:GAP UAE
2017/11/30 全球购物
2015年中学图书馆工作总结
2015/07/22 职场文书
红领巾广播站广播稿
2015/08/19 职场文书
Python3.10的一些新特性原理分析
2021/09/15 Python