Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现rest请求api示例
Apr 22 Python
python中子类继承父类的__init__方法实例
Dec 15 Python
python里使用正则的findall函数的实例详解
Oct 19 Python
python爬取网页内容转换为PDF文件
Jul 28 Python
一百行python代码将图片转成字符画
Feb 19 Python
详解python:time模块用法
Mar 25 Python
python使用pymongo操作mongo的完整步骤
Apr 13 Python
python二维码操作:对QRCode和MyQR入门详解
Jun 24 Python
django页面跳转问题及注意事项
Jul 18 Python
Python 50行爬虫抓取并处理图灵书目过程详解
Sep 20 Python
python使用nibabel和sitk读取保存nii.gz文件实例
Jul 01 Python
如何在向量化NumPy数组上进行移动窗口
May 18 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
phpMyAdmin安装并配置允许空密码登录
2015/07/04 PHP
PHPUnit + Laravel单元测试常用技能
2019/11/06 PHP
JavaScript中的面向对象介绍
2012/06/30 Javascript
JavaScript实现表格排序方法
2013/06/14 Javascript
jQuery中before()方法用法实例
2014/12/25 Javascript
原生js结合html5制作简易的双色子游戏
2015/03/30 Javascript
javascript格式化指定日期对象的方法
2015/04/21 Javascript
Jquery结合HTML5实现文件上传
2015/06/25 Javascript
JavaScript对象学习小结
2015/09/02 Javascript
轻松掌握JavaScript状态模式
2016/09/07 Javascript
JavaScript中闭包的详解
2017/04/01 Javascript
js链表操作(实例讲解)
2017/08/29 Javascript
jQuery实现base64前台加密解密功能详解
2017/08/29 jQuery
JS中跳出循环的示例代码
2017/09/14 Javascript
Js面试算法详解
2018/04/08 Javascript
JavaScript轮播停留效果的实现思路
2018/05/24 Javascript
Postman如何实现参数化执行及断言处理
2020/07/28 Javascript
Vue scoped及deep使用方法解析
2020/08/01 Javascript
使用python实现strcmp函数功能示例
2014/03/25 Python
浅谈Python的异常处理
2016/06/19 Python
Python实现爬虫设置代理IP和伪装成浏览器的方法分享
2018/05/07 Python
Python中collections模块的基本使用教程
2018/12/07 Python
使用Python和百度语音识别生成视频字幕的实现
2020/04/09 Python
matplotlib quiver箭图绘制案例
2020/04/17 Python
Django利用elasticsearch(搜索引擎)实现搜索功能
2020/11/26 Python
澳大利亚自然和有机的健康美容产品一站式商店:Ziani Beauty
2017/12/28 全球购物
惠普香港官方商店:HP香港
2019/04/30 全球购物
优秀求职信范文分享
2013/12/19 职场文书
生日答谢词
2015/01/05 职场文书
安全保证书
2015/01/16 职场文书
千手观音观后感
2015/06/03 职场文书
公司劳动纪律管理制度
2015/08/04 职场文书
小学生禁毒教育心得体会
2016/01/15 职场文书
解决Pytorch中关于model.eval的问题
2021/05/22 Python
教你部署vue项目到docker
2022/04/05 Vue.js
Vscode中SSH插件如何远程连接Linux
2022/05/02 Servers