Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python 矩阵增加一行或一列的实例
Apr 04 Python
一个简单的python爬虫程序 爬取豆瓣热度Top100以内的电影信息
Apr 17 Python
python pandas读取csv后,获取列标签的方法
Nov 12 Python
详解python配置虚拟环境
Apr 08 Python
python使用turtle绘制国际象棋棋盘
May 23 Python
python画图把时间作为横坐标的方法
Jul 07 Python
Django 创建/删除用户的示例代码
Jul 24 Python
简单了解Python多态与属性运行原理
Jun 15 Python
python import 上级目录的导入
Nov 03 Python
Pycharm 设置默认解释器路径和编码格式的操作
Feb 05 Python
用Python简陋模拟n阶魔方
Apr 17 Python
一小时学会TensorFlow2之基本操作2实例代码
Sep 04 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
社区(php&&mysql)六
2006/10/09 PHP
PHP学习 运算符与运算符优先级
2008/06/15 PHP
php连接mysql数据库代码
2009/03/10 PHP
destoon实现商铺管理主页设置增加新菜单的方法
2014/06/26 PHP
PHP实现动态web服务器方法
2015/07/29 PHP
基于yaf框架和uploadify插件,做的一个导入excel文件,查看并保存数据的功能
2017/01/24 PHP
phpStudy配置多站点多域名方法及遇到的403错误解决方法
2017/10/19 PHP
Laravel6.18.19如何优雅的切换发件账户
2020/06/14 PHP
解决extjs在firefox中关闭窗口再打开后iframe中js函数访问不到的问题
2008/11/06 Javascript
js 如何实现对数据库的增删改查
2012/11/23 Javascript
手写的一个兼容各种浏览器的javascript getStyle函数(获取元素的样式)
2014/06/06 Javascript
浅析javascript的间隔调用和延时调用
2014/11/12 Javascript
js编写一个简单的产品放大效果代码
2016/06/27 Javascript
详解Angular 4.x Injector
2017/05/04 Javascript
解决vue.js在编写过程中出现空格不规范报错的问题
2017/09/20 Javascript
vue打包之后生成一个配置文件修改接口的方法
2018/12/09 Javascript
JavaScript中的this妙用实例分析
2020/05/09 Javascript
Vue组件生命周期运行原理解析
2020/11/25 Vue.js
[01:02:25]2014 DOTA2华西杯精英邀请赛 5 24 iG VS DK
2014/05/26 DOTA
[38:42]完美世界DOTA2联赛循环赛 Matador vs Forest BO2第二场 11.05
2020/11/05 DOTA
Python爬取三国演义的实现方法
2016/09/12 Python
python爬虫之遍历单个域名
2019/11/20 Python
python进度条显示-tqmd模块的实现示例
2020/08/23 Python
Python中正则表达式对单个字符,多个字符和匹配边界等使用
2021/01/27 Python
为世界各地的女性设计和生产时尚服装:ROMWE
2016/09/17 全球购物
芬兰攀岩、山地运动和户外活动用品购物网站:Bergfreunde
2016/10/06 全球购物
英国潮流网站:END.(全球免邮)
2017/01/16 全球购物
Champion官网:美国冠军运动服装
2017/01/25 全球购物
销售人员自我评价
2014/02/01 职场文书
英文推荐信格式范文
2014/05/09 职场文书
疾病捐款倡议书
2014/05/13 职场文书
安全生产目标管理责任书
2014/07/25 职场文书
委托书如何写
2014/08/30 职场文书
小班下学期幼儿评语
2014/12/30 职场文书
Oracle设置DB、监听和EM开机启动的方法
2021/04/25 Oracle
MySQL数据库之存储过程 procedure
2022/06/16 MySQL