深入理解Tensorflow中的masking和padding


Posted in Python onFebruary 24, 2020

TensorFlow是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组,即张量(tensor)。它灵活的架构让你可以在多种平台上展开计算,例如台式计算机中的一个或多个CPU(或GPU),服务器,移动设备等等。TensorFlow 最初由Google大脑小组(隶属于Google机器智能研究机构)的研究员和工程师们开发出来,用于机器学习和深度神经网络方面的研究,但这个系统的通用性使其也可广泛用于其他计算领域。

声明:

需要读者对tensorflow和深度学习有一定了解

tf.boolean_mask实现类似numpy数组的mask操作

Python的numpy array可以使用boolean类型的数组作为索引,获得numpy array中对应boolean值为True的项。示例如下:

# numpy array中的boolean mask
import numpy as np
target_arr = np.arange(5)
print "numpy array before being masked:"
print target_arr
mask_arr = [True, False, True, False, False]
masked_arr = target_arr[mask_arr]
print "numpy array after being masked:"
print masked_arr

运行结果如下:

numpy array before being masked: [0 1 2 3 4] numpy array after being masked: [0 2]

tf.boolean_maks对目标tensor实现同上述numpy array一样的mask操作,该函数的参数也比较简单,如下所示:

tf.boolean_mask(
 tensor, # target tensor
 mask, # mask tensor
 axis=None,
 name='boolean_mask'
)

下面,我们来尝试一下tf.boolean_mask函数,示例如下:

import tensorflow as tf
# tensorflow中的boolean mask
target_tensor = tf.constant([[1, 2], [3, 4], [5, 6]])
mask_tensor = tf.constant([True, False, True])
masked_tensor = tf.boolean_mask(target_tensor, mask_tensor, axis=0)
sess = tf.InteractiveSession()
print masked_tensor.eval()

mask tensor中的第0和第2个元素是True,mask axis是第0维,也就是我们只选择了target tensor的第0行和第1行。

[[1 2] [5 6]]

如果把mask tensor也换成2维的tensor会怎样呢?

mask_tensor2 = tf.constant([[True, False], [False, False], [True, False]])
masked_tensor2 = tf.boolean_mask(target_tensor, mask_tensor, axis=0)
print masked_tensor2.eval()

[[1 2] [5 6]]

我们发现,结果不是[[1], [5]]。tf.boolean_mask不做元素维度的mask,tersorflow中有tf.ragged.boolean_mask实现元素维度的mask。

tf.ragged.boolean_mask
tf.ragged.boolean_mask(
 data,
 mask,
 name=None
)

tensorflow中的sparse向量和sparse mask tensorflow中的sparse tensor由三部分组成,分别是indices、values、dense_shape。对于稀疏张量SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]),转化成dense tensor的值为:

[[1, 0, 0, 0] [0, 0, 2, 0] [0, 0, 0, 0]]

使用tf.sparse.mask可以对sparse tensor执行mask操作。

tf.sparse.mask(
 a,
 mask_indices,
 name=None
)

上文定义的sparse tensor有1和2两个值,对应的indices为[[0, 0], [1, 2]],执行tf.sparsse.mask(a, [[1, 2]])后,稀疏向量转化成dense的值为:

[[1, 0, 0, 0] [0, 0, 0, 0] [0, 0, 0, 0]]

由于tf.sparse中的大多数函数都只在tensorflow2.0版本中有,所以没有实例演示。

padded_batch

tf.Dataset中的padded_batch函数,根据输入序列中的最大长度,自动的pad一个batch的序列。

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None,
 drop_remainder=False
)

这个函数与tf.Dataset中的batch函数对应,都是基于dataset构造batch,但是batch函数需要dataset中的所有样本形状相同,而padded_batch可以将不同形状的样本在构造batch时padding成一样的形状。

elements = [[1, 2], 
  [3, 4, 5], 
  [6, 7], 
  [8]] 
A = tf.data.Dataset.from_generator(lambda: iter(elements), tf.int32) 
B = A.padded_batch(2, padded_shapes=[None]) 
B_iter = B.make_one_shot_iterator()
print B_iter.get_next().eval()

[[1 2 0] [3 4 5]]

总结

到此这篇关于深入理解Tensorflow中的masking和padding的文章就介绍到这了,更多相关Tensorflow中的masking和padding内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python专用方法与迭代机制实例分析
Sep 15 Python
Python列表append和+的区别浅析
Feb 02 Python
Python中的Classes和Metaclasses详解
Apr 02 Python
Python中规范定义命名空间的一些建议
Jun 04 Python
windows下python连接oracle数据库
Jun 07 Python
django中静态文件配置static的方法
May 20 Python
python实现批量图片格式转换
Jun 16 Python
对python中Librosa的mfcc步骤详解
Jan 09 Python
python 图像平移和旋转的实例
Jan 10 Python
django的403/404/500错误自定义页面的配置方式
May 21 Python
详解python中GPU版本的opencv常用方法介绍
Jul 24 Python
超级详细实用的pycharm常用快捷键
May 12 Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
Python接口自动化判断元素原理解析
Feb 24 #Python
python使用turtle库绘制奥运五环
Feb 24 #Python
You might like
虫族 Zerg 历史背景
2020/03/14 星际争霸
扩展你的 PHP 之入门篇
2006/12/04 PHP
用穿越火线快速入门php面向对象
2012/02/22 PHP
PHP错误WARNING: SESSION_START() [FUNCTION.SESSION-START]解决方法
2014/05/04 PHP
PHP实现生成唯一会员卡号
2015/08/24 PHP
详解WordPress中添加友情链接的方法
2016/05/21 PHP
php页面跳转session cookie丢失导致不能登录等问题的解决方法
2016/12/12 PHP
PHP网页安全认证的实例详解
2017/09/28 PHP
PHP读取文件或采集时解决中文乱码
2021/03/09 PHP
几个有趣的Javascript Hack
2010/07/24 Javascript
jquery鼠标停止移动事件
2013/12/21 Javascript
js读写json文件实例代码
2014/10/21 Javascript
jQuery源码分析之jQuery.fn.each与jQuery.each用法
2015/01/23 Javascript
js实现select下拉框菜单
2015/12/08 Javascript
JavaScript原生对象常用方法总结(推荐)
2016/05/13 Javascript
微信小程序入门教程
2016/11/18 Javascript
原生JavaScript实现Tooltip浮动提示框特效
2017/03/07 Javascript
Bootstrap modal只加载一次数据的解决办法(推荐)
2017/11/24 Javascript
Vue EventBus自定义组件事件传递
2018/06/25 Javascript
使用vue-cli脚手架工具搭建vue-webpack项目
2019/01/14 Javascript
小程序显示弹窗时禁止下层的内容滚动实现方法
2019/03/20 Javascript
nodejs环境使用Typeorm连接查询Oracle数据
2019/12/05 NodeJs
Vue data的数据响应式到底是如何实现的
2020/02/11 Javascript
《javascript设计模式》学习笔记三:Javascript面向对象程序设计单例模式原理与实现方法分析
2020/04/07 Javascript
[44:43]完美世界DOTA2联赛决赛日 FTD vs GXR 第一场 11.08
2020/11/11 DOTA
Python fileinput模块使用介绍
2014/11/30 Python
解决python2.7用pip安装包时出现错误的问题
2017/01/23 Python
对numpy Array [: ,] 的取值方法详解
2018/07/02 Python
numpy 计算两个数组重复程度的方法
2018/11/07 Python
python虚拟环境迁移方法
2019/01/03 Python
python实现在列表中查找某个元素的下标示例
2020/11/16 Python
Pottery Barn阿联酋:购买家具、家居装饰及更多
2019/12/08 全球购物
世界经理人咨询有限公司面试
2014/09/23 面试题
护士毕业实习感言
2014/03/05 职场文书
个人投资合作协议书
2014/10/12 职场文书
go原生库的中bytes.Buffer用法
2021/04/25 Golang