对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解


Posted in Python onFebruary 11, 2020

在用tensorflow做一维的卷积神经网络的时候会遇到tf.nn.conv1d和layers.conv1d这两个函数,但是这两个函数有什么区别呢,通过计算得到一些规律。

1.关于tf.nn.conv1d的解释,以下是Tensor Flow中关于tf.nn.conv1d的API注解:

Computes a 1-D convolution given 3-D input and filter tensors.

Given an input tensor of shape [batch, in_width, in_channels] if data_format is "NHWC", or [batch, in_channels, in_width] if data_format is "NCHW", and a filter / kernel tensor of shape [filter_width, in_channels, out_channels], this op reshapes the arguments to pass them to conv2d to perform the equivalent convolution operation.

Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`. For example, if `data_format` does not start with "NC", a tensor of shape [batch, in_width, in_channels] is reshaped to [batch, 1, in_width, in_channels], and the filter is reshaped to [1, filter_width, in_channels, out_channels]. The result is then reshaped back to [batch, out_width, out_channels] whereoutwidthisafunctionofthestrideandpaddingasinconv2dwhereoutwidthisafunctionofthestrideandpaddingasinconv2d and returned to the caller.

Args: value: A 3D `Tensor`. Must be of type `float32` or `float64`. filters: A 3D `Tensor`. Must have the same type as `input`. stride: An `integer`. The number of entries by which the filter is moved right at each step. padding: 'SAME' or 'VALID' use_cudnn_on_gpu: An optional `bool`. Defaults to `True`. data_format: An optional `string` from `"NHWC", "NCHW"`. Defaults to `"NHWC"`, the data is stored in the order of [batch, in_width, in_channels]. The `"NCHW"` format stores data as [batch, in_channels, in_width]. name: A name for the operation (optional).

Returns:

A `Tensor`. Has the same type as input.

Raises:

ValueError: if `data_format` is invalid.

什么意思呢?就是说conv1d的参数含义:(以NHWC格式为例,即,通道维在最后)

1、value:在注释中,value的格式为:[batch, in_width, in_channels],batch为样本维,表示多少个样本,in_width为宽度维,表示样本的宽度,in_channels维通道维,表示样本有多少个通道。 事实上,也可以把格式看作如下:[batch, 行数, 列数],把每一个样本看作一个平铺开的二维数组。这样的话可以方便理解。

2、filters:在注释中,filters的格式为:[filter_width, in_channels, out_channels]。按照value的第二种看法,filter_width可以看作每次与value进行卷积的行数,in_channels表示value一共有多少列(与value中的in_channels相对应)。out_channels表示输出通道,可以理解为一共有多少个卷积核,即卷积核的数目。

3、stride:一个整数,表示步长,每次(向下)移动的距离(TensorFlow中解释是向右移动的距离,这里可以看作向下移动的距离)。

4、padding:同conv2d,value是否需要在下方填补0。

5、name:名称。可省略。

首先从参数列表可以看出value指的输入的数据,stride就是卷积的步长,这里我们最有疑问的就是filters这个参数,那么我们对filter进行简单的说明。从上面可以看到filters的格式为:[filter_width, in_channels, out_channels],这是一个数组的维度,对应的是卷积核的大小,输入的channel的格式,和卷积核的个数,下面我们用例子说明问题:

import tensorflow as tf
import numpy as np
 
 
if __name__ == '__main__':
  inputs = tf.constant(np.arange(1, 6, dtype=np.float32), shape=[1, 5, 1])
  w = np.array([1, 2], dtype=np.float32).reshape([2, 1, 1])
  # filter width, filter channels and out channels(number of kernels)
  cov1 = tf.nn.conv1d(inputs, w, stride=1, padding='VALID')
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(cov1)
    print(out)

其输出为:

[[[ 5.],
    [ 8.],
    [11.],
    [14.]]]

我们分析一下,输入的数据为[[[1],[2],[3],[4],[5]]],有5个特征,分别对应的数值为1,2,3,4,5,那么经过卷积的结果为5,8,11,14,那么这个结果是怎么来的呢,我们根据卷积的计算,可以得到5 = 1*1 + 2*2, 8=2*1+ 3*2, 11 = 3*1+4*2, 14=4*1+5*2, 也就是W1=1, W2=2,正好和我们先面filters设置的数值相等,

w = np.array([1, 2], dtype=np.float32).reshape([2, 1, 1])

所以可以看到这个filtes设置的是是卷积核矩阵的,换句话说,卷积核矩阵我们是可以设置的。

2. 1.关于tf.layers.conv1d,函数的定义如下

tf.layers.conv1d(
 
inputs,
 
filters,
 
kernel_size,
 
strides=1,
 
padding='valid',
 
data_format='channels_last',
 
dilation_rate=1,
 
activation=None,
 
use_bias=True,
 
kernel_initializer=None,
 
bias_initializer=tf.zeros_initializer(),
 
kernel_regularizer=None,
 
bias_regularizer=None,
 
activity_regularizer=None,
 
kernel_constraint=None,
 
bias_constraint=None,
 
trainable=True,
 
name=None,
 
reuse=None
 
)

比较重要的几个参数是inputs, filters, kernel_size,下面分别说明

inputs : 输入tensor, 维度(None, a, b) 是一个三维的tensor

None : 一般是填充样本的个数,batch_size

a : 句子中的词数或者字数

b : 字或者词的向量维度

filters : 过滤器的个数

kernel_size : 卷积核的大小,卷积核其实应该是一个二维的,这里只需要指定一维,是因为卷积核的第二维与输入的词向量维度是一致的,因为对于句子而言,卷积的移动方向只能是沿着词的方向,即只能在列维度移动。一个例子:

import tensorflow as tf
import numpy as np
 
 
if __name__ == '__main__':
  inputs = tf.constant(np.arange(1, 6, dtype=np.float32), shape=[1, 5, 1])
  cov2 = tf.layers.conv1d(inputs, filters=1, kernel_size=2, strides=1, padding='VALID')
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(cov2)
    print(out)

输出结果:

[[[-1.9953331]
 [-3.5520997]
 [-5.108866 ]
 [-6.6656327]]]

也许你得到的结果和我得到的结果不同,因为在这个函数里面只是设置了卷积核的尺寸和步长,没有设置具体的卷积核矩阵,所以这个卷积核矩阵是随机生成的,就会出现可能运行上面的程序出现不同结果的情况。

以上这篇对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python Django批量导入不重复数据
Mar 25 Python
Python入门之三角函数atan2()函数详解
Nov 08 Python
Numpy中stack(),hstack(),vstack()函数用法介绍及实例
Jan 09 Python
Django集成CAS单点登录的方法示例
Jun 10 Python
PyQt5 多窗口连接实例
Jun 19 Python
python实现的批量分析xml标签中各个类别个数功能示例
Dec 30 Python
python GUI库图形界面开发之PyQt5树形结构控件QTreeWidget详细使用方法与实例
Mar 02 Python
python闭包与引用以及需要注意的陷阱
Sep 18 Python
python时间time模块处理大全
Oct 25 Python
详解Python流程控制语句
Oct 28 Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 Python
python实现学员管理系统(面向对象版)
Jun 05 Python
python中文分词库jieba使用方法详解
Feb 11 #Python
Transpose 数组行列转置的限制方式
Feb 11 #Python
Tensorflow:转置函数 transpose的使用详解
Feb 11 #Python
tensorflow多维张量计算实例
Feb 11 #Python
python误差棒图errorbar()函数实例解析
Feb 11 #Python
解决Python3.8用pip安装turtle-0.0.2出现错误问题
Feb 11 #Python
python scatter函数用法实例详解
Feb 11 #Python
You might like
如何隐藏你的.php文件
2007/01/04 PHP
php中inlcude()性能对比详解
2012/09/16 PHP
分享下页面关键字抓取components.arrow.com站点代码
2014/01/30 PHP
PHP中如何防止外部恶意提交调用ajax接口
2016/04/11 PHP
PHP设计模式之工厂模式与单例模式
2016/09/28 PHP
解决php用mysql方式连接数据库出现Deprecated报错问题
2019/12/25 PHP
JavaScript中遍历对象的property的3种方法介绍
2014/12/30 Javascript
JS运动相关知识点小结(附弹性运动示例)
2016/01/08 Javascript
javascript鼠标右键菜单自定义效果
2020/12/08 Javascript
JavaScript中各种引用类型的常用操作方法小结
2016/05/05 Javascript
Highcharts 多个Y轴动态刷新数据的实现代码
2016/05/28 Javascript
jquery编写日期选择器
2017/03/16 Javascript
javascript 动态生成css代码的两种方法
2017/03/17 Javascript
如何使用angularJs
2017/05/08 Javascript
Vue组件通信的几种实现方法
2019/04/25 Javascript
vue读取本地的excel文件并显示在网页上方法示例
2019/05/29 Javascript
js实现数据导出为EXCEL(支持大量数据导出)
2020/03/31 Javascript
Python标准库inspect的具体使用方法
2017/12/06 Python
python增加矩阵维度的实例讲解
2018/04/04 Python
用Python编写一个高效的端口扫描器的方法
2018/12/20 Python
Python3.5内置模块之os模块、sys模块、shutil模块用法实例分析
2019/04/27 Python
Python 识别12306图片验证码物品的实现示例
2020/01/20 Python
浅析python中的del用法
2020/09/02 Python
Python爬虫scrapy框架Cookie池(微博Cookie池)的使用
2021/01/13 Python
如何利用CSS3制作3D效果文字具体实现样式
2013/05/02 HTML / CSS
HTML5 本地存储之如果没有数据库究竟会怎样
2013/04/25 HTML / CSS
详解如何通过H5(浏览器/WebView/其他)唤起本地app
2017/12/11 HTML / CSS
约瑟夫·特纳男装:Joseph Turner
2017/10/10 全球购物
潘多拉珠宝俄罗斯官方网上商店:PANDORA俄罗斯
2020/09/22 全球购物
什么是View State?
2013/01/27 面试题
中学教师管理制度
2014/01/14 职场文书
党支部创先争优公开承诺书
2015/04/30 职场文书
反腐倡廉主题教育活动总结
2015/05/07 职场文书
辅导员学期工作总结
2015/08/14 职场文书
《认识年月日》教学反思
2016/02/19 职场文书
HTML+CSS+JS实现图片的瀑布流布局的示例代码
2021/04/22 HTML / CSS