Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
使用httplib模块来制作Python下HTTP客户端的方法
Jun 19 Python
Win10下Python环境搭建与配置教程
Nov 18 Python
基于Python List的赋值方法
Jun 23 Python
python 获取一个值在某个区间的指定倍数的值方法
Nov 12 Python
Django框架模板语言实例小结【变量,标签,过滤器,继承,html转义】
May 23 Python
解决django中ModelForm多表单组合的问题
Jul 18 Python
Flask使用Pyecharts在单个页面展示多个图表的方法
Aug 05 Python
Python调用Windows命令打印文件
Feb 07 Python
django实现模型字段动态choice的操作
Apr 01 Python
Python - 10行代码集2000张美女图
May 23 Python
Jupyter Notebook内使用argparse报错的解决方案
Jun 03 Python
python之基数排序的实现
Jul 26 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
PHP内置过滤器FILTER使用实例
2014/06/25 PHP
PHP使用DOMDocument类生成HTML实例(包含常见标签元素)
2014/06/25 PHP
PHP关于foreach复制知识点总结
2019/01/28 PHP
快速保存网页中所有图片的方法
2006/06/23 Javascript
jquery获取下拉列表的值为null的解决方法
2011/03/18 Javascript
几种延迟加载JS代码的方法加快网页的访问速度
2013/10/12 Javascript
在Python中使用glob模块查找文件路径的方法
2015/06/17 Javascript
jquery实现可关闭的倒计时广告特效代码
2015/09/02 Javascript
JavaScript 最佳实践:帮你提升代码质量
2016/12/03 Javascript
微信小程序 利用css实现遮罩效果实例详解
2017/01/21 Javascript
详解nodejs中exports和module.exports的区别
2017/02/17 NodeJs
实例讲解DataTables固定表格宽度(设置横向滚动条)
2017/07/11 Javascript
Vue内容分发slot(全面解析)
2017/08/19 Javascript
解决使用vue.js路由后失效的问题
2018/03/17 Javascript
Vue3.0结合bootstrap创建多页面应用
2019/05/28 Javascript
js实现跳一跳小游戏
2020/07/31 Javascript
jquery自定义组件实例详解
2020/12/31 jQuery
初学Python实用技巧两则
2014/08/29 Python
Python实现非正太分布的异常值检测方式
2019/12/09 Python
TensorFlow梯度求解tf.gradients实例
2020/02/04 Python
Python中有几个关键字
2020/06/04 Python
小天鹅官方商城:LittleSwan
2017/06/16 全球购物
Java里面如何创建一个内部类的实例
2015/01/19 面试题
高中英语教学反思
2014/02/04 职场文书
《蓝色的树叶》教学反思
2014/02/24 职场文书
端午节活动策划方案
2014/03/09 职场文书
生产文员岗位职责
2014/04/05 职场文书
个人反四风对照检查材料思想汇报
2014/09/23 职场文书
安阳殷墟导游词
2015/02/10 职场文书
在职人员跳槽求职信
2015/03/20 职场文书
2015新学期校长寄语(3篇)
2015/03/25 职场文书
优秀教师工作总结2015
2015/07/22 职场文书
2015年十月一日放假通知
2015/08/18 职场文书
施工安全协议书
2016/03/22 职场文书
python使用XPath解析数据爬取起点小说网数据
2021/04/22 Python
Axios取消重复请求的方法实例详解
2021/06/15 Javascript