Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
通过代码实例展示Python中列表生成式的用法
Mar 31 Python
疯狂上涨的Python 开发者应从2.x还是3.x着手?
Nov 16 Python
详解TensorFlow在windows上安装与简单示例
Mar 05 Python
flask/django 动态查询表结构相同表名不同数据的Model实现方法
Aug 29 Python
pygame实现俄罗斯方块游戏(AI篇1)
Oct 29 Python
python ftplib模块使用代码实例
Dec 31 Python
python3中sorted函数里cmp参数改变详解
Mar 12 Python
pycharm解决关闭flask后依旧可以访问服务的问题
Apr 03 Python
Python selenium 加载并保存QQ群成员,去除其群主、管理员信息的示例代码
May 28 Python
Pytorch 卷积中的 Input Shape用法
Jun 29 Python
python 如何将office文件转换为PDF
Sep 22 Python
python基于socket模拟实现ssh远程执行命令
Dec 05 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
mysql 字段类型说明
2007/04/27 PHP
PHP5中新增stdClass 内部保留类
2011/06/13 PHP
url decode problem 解决方法
2011/12/26 PHP
php使用mb_check_encoding检查字符串在指定的编码里是否有效
2013/11/07 PHP
php时间函数用法分析
2016/05/28 PHP
WordPress 插件——CoolCode使用方法与下载
2007/07/02 Javascript
Jquery Ajax学习实例4 向WebService发出请求,返回实体对象的异步调用
2010/03/16 Javascript
读jQuery之六 缓存数据功能介绍
2011/06/21 Javascript
js实现的黑背景灰色二级导航菜单效果代码
2015/08/24 Javascript
基于jQuery实现在线选座之高铁版
2015/08/24 Javascript
Bootstrap Table表格一直加载(load)不了数据的快速解决方法
2016/09/17 Javascript
Angular4 中内置指令的基本用法
2017/07/31 Javascript
Vue导出json数据到Excel电子表格的示例
2017/12/04 Javascript
对mac下nodejs 更新到最新版本的最新方法(推荐)
2018/05/17 NodeJs
webpack-url-loader 解决项目中图片打包路径问题
2019/02/15 Javascript
微信小程序 scroll-view的使用案例代码详解
2020/06/11 Javascript
vue实现移动端项目多行文本溢出省略
2020/07/29 Javascript
Win7下搭建python开发环境图文教程(安装Python、pip、解释器)
2016/05/17 Python
python爬虫的工作原理
2017/03/05 Python
python爬虫获取京东手机图片的图文教程
2017/12/29 Python
python bmp转换为jpg 并删除原图的方法
2018/10/25 Python
python中通过selenium简单操作及元素定位知识点总结
2019/09/10 Python
Python3实现监控新型冠状病毒肺炎疫情的示例代码
2020/02/13 Python
django使用多个数据库的方法实例
2021/03/04 Python
css3实现元素环绕中心点布局的方法示例
2019/01/15 HTML / CSS
详解background属性的8个属性值(面试题)
2020/11/02 HTML / CSS
俄罗斯首家面向中国消费者的一站式购物网站:Wruru
2020/05/08 全球购物
学校门卫工作职责
2013/12/07 职场文书
车间班组长的职责
2013/12/13 职场文书
保护环境的建议书
2014/03/12 职场文书
法律顾问服务方案
2014/05/15 职场文书
PHP判断是否是json字符串
2021/04/01 PHP
zabbix agent2 监控oracle数据库的方法
2021/05/13 Oracle
Vue3.0 手写放大镜效果
2021/07/25 Vue.js
 Redis 串行生成顺序编码的方法实现
2022/04/03 Redis
vue选项卡切换的实现案例
2022/04/11 Vue.js