Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python使用xmlrpc实例讲解
Dec 17 Python
python爬取51job中hr的邮箱
May 14 Python
今天 平安夜 Python 送你一顶圣诞帽 @微信官方
Dec 25 Python
python简单验证码识别的实现方法
May 10 Python
浅谈Python大神都是这样处理XML文件的
May 31 Python
python hashlib加密实现代码
Oct 17 Python
Python搭建代理IP池实现检测IP的方法
Oct 27 Python
python主线程与子线程的结束顺序实例解析
Dec 17 Python
Python将二维列表list的数据输出(TXT,Excel)
Apr 23 Python
Python  word实现读取及导出代码解析
Jul 09 Python
python关于倒排列的知识点总结
Oct 13 Python
python处理json数据文件
Apr 11 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
php简单静态页生成过程
2008/03/27 PHP
php判断两个浮点数是否相等的方法
2015/03/14 PHP
Json_decode 解析json字符串为NULL的解决方法(必看)
2017/02/17 PHP
PHP laravel中的多对多关系实例详解
2017/06/07 PHP
javascript面向对象的方式实现的弹出层效果代码
2010/01/28 Javascript
javascript定义函数的方法
2010/12/06 Javascript
利用ajaxfileupload插件实现文件上传无刷新的具体方法
2013/06/08 Javascript
jquery 层次选择器siblings与nextAll的区别介绍
2013/08/02 Javascript
js数组操作常用方法
2014/05/08 Javascript
浅谈js的setInterval事件
2014/12/05 Javascript
使用CoffeeScrip优美方式编写javascript代码
2015/10/28 Javascript
jQuery提示插件qTip2用法分析(支持ajax及多种样式)
2016/06/08 Javascript
jQuery获取浏览器类型和版本号的方法
2016/07/05 Javascript
Sortable.js拖拽排序使用方法解析
2016/11/04 Javascript
js图片轮播手动切换特效
2017/01/12 Javascript
js中的面向对象入门
2017/03/06 Javascript
JavaScript 上传文件(psd,压缩包等),图片,视频的实现方法
2017/06/19 Javascript
javascript定时器取消定时器及优化方法
2017/07/08 Javascript
label+input实现按钮开关切换效果的实例
2017/08/16 Javascript
Vue添加请求拦截器及vue-resource 拦截器使用
2017/11/23 Javascript
Webpack实战加载SVG的方法
2017/12/26 Javascript
nodeJs爬虫的技术点总结
2018/05/13 NodeJs
vue用递归组件写树形控件的实例代码
2018/07/19 Javascript
jQuery操作事件完整实例分析
2020/01/10 jQuery
python 读取txt中每行数据,并且保存到excel中的实例
2018/04/29 Python
pycharm的console输入实现换行的方法
2019/01/16 Python
Pytorch修改ResNet模型全连接层进行直接训练实例
2019/09/10 Python
Python 限定函数参数的类型及默认值方式
2019/12/24 Python
在python中使用pymysql往mysql数据库中插入(insert)数据实例
2020/03/02 Python
python将YUV420P文件转PNG图片格式的两种方法
2021/01/22 Python
详解Css3新特性应用之过渡与动画
2017/01/10 HTML / CSS
关键字throw与throws的用法差异
2016/11/22 面试题
水利公司纪检监察自我鉴定
2014/02/25 职场文书
财产公证书格式
2014/04/10 职场文书
英语系本科生求职信
2014/07/15 职场文书
党政领导班子群众路线对照检查材料
2014/10/26 职场文书