Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现将DOC文档转换为PDF的方法
Jul 25 Python
Python smtplib实现发送邮件功能
May 22 Python
python实现事件驱动
Nov 21 Python
三步实现Django Paginator分页的方法
Jun 11 Python
通过python实现随机交换礼物程序详解
Jul 10 Python
python监控进程状态,记录重启时间及进程号的实例
Jul 15 Python
pip安装python库的方法总结
Aug 02 Python
python解析yaml文件过程详解
Aug 30 Python
Python Celery多队列配置代码实例
Nov 22 Python
Python自动创建Excel并获取内容
Sep 16 Python
python3实现语音转文字(语音识别)和文字转语音(语音合成)
Oct 14 Python
matplotlib相关系统目录获取方式小结
Feb 03 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
用php来检测proxy
2006/10/09 PHP
PHP中输出转义JavaScript代码的实现代码
2011/04/22 PHP
php继承中方法重载(覆盖)的应用场合
2015/02/09 PHP
php 与 nginx 的处理方式及nginx与php-fpm通信的两种方式
2018/09/28 PHP
laravel-admin表单提交隐藏一些数据,回调时获取数据的方法
2019/10/08 PHP
PHP xpath提取网页数据内容代码解析
2020/07/16 PHP
JavaScript 学习初步 入门教程
2010/03/25 Javascript
js switch case default 的用法示例介绍
2013/10/23 Javascript
node.js中的fs.readlink方法使用说明
2014/12/17 Javascript
js+html5绘制图片到canvas的方法
2015/06/05 Javascript
JavaScript闭包和范围实例详解
2016/12/19 Javascript
Vue.js自定义指令的用法与实例解析
2017/01/18 Javascript
Ionic项目中Native Camera的使用方法
2017/06/07 Javascript
微信小程序注册60s倒计时功能 使用JS实现注册60s倒计时功能
2017/08/16 Javascript
Vue 列表上下过渡效果的实例代码
2019/06/25 Javascript
Node.js中文件系统fs模块的使用及常用接口
2020/03/06 Javascript
微信小程序自定义支持图片的弹窗
2020/12/21 Javascript
Python文件操作之合并文本文件内容示例代码
2017/09/19 Python
Python异常对代码运行性能的影响实例解析
2018/02/08 Python
Python3+Appium安装使用教程
2019/07/05 Python
Pycharm使用之设置代码字体大小和颜色主题的教程
2019/07/12 Python
win10系统下python3安装及pip换源和使用教程
2020/01/06 Python
简单了解python关键字global nonlocal区别
2020/09/21 Python
东方电视购物:东方CJ
2016/10/12 全球购物
巴西男士个人护理产品商店:SHOP4MEN
2017/08/07 全球购物
万宝龙英国官网:Montblanc手表、书写工具、皮革和珠宝
2018/10/16 全球购物
俄罗斯韩国化妆品网上商店:Cosmasi.ru
2019/10/31 全球购物
巴西葡萄酒商店:Divvino
2020/02/22 全球购物
Kipling澳洲官网:购买凯浦林包包
2020/12/17 全球购物
美德好少年事迹材料
2014/01/19 职场文书
闪闪红星观后感
2015/06/08 职场文书
升学宴家长答谢词
2015/09/29 职场文书
预防职务犯罪警示教育心得体会
2016/01/15 职场文书
护士爱岗敬业心得体会
2016/01/25 职场文书
浅谈Python数学建模之线性规划
2021/06/23 Python
九大龙王魂骨,山龙王留下躯干骨,榜首死的最憋屈(被捏碎)
2022/03/18 国漫