基于Tensorflow一维卷积用法详解


Posted in Python onMay 22, 2020

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
import numpy as np
input = tf.constant(1,shape=(64,10,1),dtype=tf.float32,name='input')#shape=(batch,in_width,in_channels)
w = tf.constant(3,shape=(3,1,32),dtype=tf.float32,name='w')#shape=(filter_width,in_channels,out_channels)
conv1 = tf.nn.conv1d(input,w,2,'VALID') #2为步长
print(conv1.shape)#宽度计算(width-kernel_size+1)/strides ,(10-3+1)/2=4 (64,4,32)
conv2 = tf.nn.conv1d(input,w,2,'SAME') #步长为2
print(conv2.shape)#宽度计算width/strides 10/2=5 (64,5,32)
conv3 = tf.nn.conv1d(input,w,1,'SAME') #步长为1
print(conv3.shape) # (64,10,32)
with tf.Session() as sess:
 print(sess.run(conv1))
 print(sess.run(conv2))
 print(sess.run(conv3))

基于Tensorflow一维卷积用法详解

以下是input_shape=(1,10,1), w = (3,1,1)时,conv1的shape

基于Tensorflow一维卷积用法详解

以下是input_shape=(1,10,1), w = (3,1,3)时,conv1的shape

基于Tensorflow一维卷积用法详解

补充知识:tensorflow中一维卷积conv1d处理语言序列举例

tf.nn.conv1d:

函数形式: tf.nn.conv1d(value, filters, stride, padding, use_cudnn_on_gpu=None, data_format=None, name=None):

程序举例:

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
 
# --------------- tf.nn.conv1d -------------------
inputs=tf.ones((64,10,3)) # [batch, n_sqs, embedsize]
w=tf.constant(1,tf.float32,(5,3,32)) # [w_high, embedsize, n_filers]
conv1 = tf.nn.conv1d(inputs,w,stride=2 ,padding='SAME') # conv1=[batch, round(n_sqs/stride), n_filers],stride是步长。
 
tf.global_variables_initializer().run()
out = sess.run(conv1)
print(out)

注:一维卷积中padding='SAME'只在输入的末尾填充0

tf.layters.conv1d:

函数形式:tf.layters.conv1d(inputs, filters, kernel_size, strides=1, padding='valid', data_format='channels_last', dilation_rate=1, activation=None, use_bias=True,...)

程序举例:

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
 
# --------------- tf.layters.conv1d -------------------
inputs=tf.ones((64,10,3)) # [batch, n_sqs, embedsize]
num_filters=32
kernel_size =5
conv2 = tf.layers.conv1d(inputs, num_filters, kernel_size,strides=2, padding='valid',name='conv2') # shape = (batchsize, round(n_sqs/strides),num_filters)
tf.global_variables_initializer().run()
out = sess.run(conv2)
print(out)

二维卷积实现一维卷积:

import tensorflow as tf
sess = tf.InteractiveSession()
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
def max_pool_1x2(x):
 return tf.nn.avg_pool(x, ksize=[1,1,2,1], strides=[1,1,2,1], padding='SAME')
'''
ksize = [x, pool_height, pool_width, x]
strides = [x, pool_height, pool_width, x]
'''
 
x = tf.Variable([[1,2,3,4]], dtype=tf.float32)
x = tf.reshape(x, [1,1,4,1]) #这一步必不可少,否则会报错说维度不一致;
'''
[batch, in_height, in_width, in_channels] = [1,1,4,1]
'''
 
W_conv1 = tf.Variable([1,1,1],dtype=tf.float32) # 权重值
W_conv1 = tf.reshape(W_conv1, [1,3,1,1]) # 这一步同样必不可少
'''
[filter_height, filter_width, in_channels, out_channels]
'''
h_conv1 = conv2d(x, W_conv1) # 结果:[4,8,12,11]
h_pool1 = max_pool_1x2(h_conv1)
tf.global_variables_initializer().run()
print(sess.run(h_conv1)) # 结果array([6,11.5])x

两种池化操作:

# 1:stride max pooling
convs = tf.expand_dims(conv, axis=-1) # shape=[?,596,256,1]
smp = tf.nn.max_pool(value=convs, ksize=[1, 3, self.config.num_filters, 1], strides=[1, 3, 1, 1],
     padding='SAME') # shape=[?,299,256,1]
smp = tf.squeeze(smp, -1) # shape=[?,299,256]
smp = tf.reshape(smp, shape=(-1, 199 * self.config.num_filters))
 
# 2: global max pooling layer
gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')

不同核尺寸卷积操作:

kernel_sizes = [3,4,5] # 分别用窗口大小为3/4/5的卷积核
with tf.name_scope("mul_cnn"):
 pooled_outputs = []
 for kernel_size in kernel_sizes:
  # CNN layer
  conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, kernel_size, name='conv-%s' % kernel_size)
  # global max pooling layer
  gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')
  pooled_outputs.append(gmp)
 self.h_pool = tf.concat(pooled_outputs, 1) #池化后进行拼接

以上这篇基于Tensorflow一维卷积用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python编程实现随机生成多个椭圆实例代码
Jan 03 Python
详解Python自建logging模块
Jan 29 Python
Python wxpython模块响应鼠标拖动事件操作示例
Aug 23 Python
Pandas:Series和DataFrame删除指定轴上数据的方法
Nov 10 Python
python 接收处理外带的参数方法
Dec 03 Python
在python 中实现运行多条shell命令
Jan 07 Python
Python爬取知乎图片代码实现解析
Sep 17 Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 Python
python爬虫之遍历单个域名
Nov 20 Python
python 双循环遍历list 变量判断代码
May 04 Python
发工资啦!教你用Python实现邮箱自动群发工资条
May 10 Python
python中subplot大小的设置步骤
Jun 28 Python
Python参数传递机制传值和传引用原理详解
May 22 #Python
python filecmp.dircmp实现递归比对两个目录的方法
May 22 #Python
关于keras.layers.Conv1D的kernel_size参数使用介绍
May 22 #Python
Python参数传递对象的引用原理解析
May 22 #Python
Python configparser模块常用方法解析
May 22 #Python
keras中的卷积层&池化层的用法
May 22 #Python
Keras Convolution1D与Convolution2D区别说明
May 22 #Python
You might like
PHP中Header使用的HTTP协议及常用方法小结
2014/11/04 PHP
php使用GD创建保持宽高比缩略图的方法
2015/04/17 PHP
php数值转换时间及时间转换数值用法示例
2017/05/18 PHP
JavaScript 中的replace方法说明
2007/04/13 Javascript
javascript中window.event事件用法详解
2012/12/11 Javascript
jQuery中:checked选择器用法实例
2015/01/04 Javascript
js去除浏览器默认底图的方法
2015/06/08 Javascript
JavaScript获取对象在页面中位置坐标的方法
2016/02/03 Javascript
JavaScript 继承详解(六)
2016/10/11 Javascript
JS 组件系列之Bootstrap Table 冻结列功能IE浏览器兼容性问题解决方案
2017/06/30 Javascript
VUE:vuex 用户登录信息的数据写入与获取方式
2019/11/11 Javascript
uin-app+mockjs实现本地数据模拟
2020/08/26 Javascript
Python爬虫之xlml解析库(全面了解)
2017/08/08 Python
Python实现将一个正整数分解质因数的方法分析
2017/12/14 Python
神经网络理论基础及Python实现详解
2017/12/15 Python
python文本数据相似度的度量
2018/03/12 Python
python的dataframe和matrix的互换方法
2018/04/11 Python
在Python中调用Ping命令,批量IP的方法
2019/01/26 Python
python GUI库图形界面开发之PyQt5多行文本框控件QTextEdit详细使用方法实例
2020/02/28 Python
python中JWT用户认证的实现
2020/05/18 Python
Python实现文件压缩和解压的示例代码
2020/08/12 Python
css3中background新增的4个新的相关属性用法介绍
2013/09/26 HTML / CSS
CSS3实现鼠标悬停显示扩展内容
2016/08/24 HTML / CSS
详解HTML5中ol标签的用法
2015/09/08 HTML / CSS
欧洲最大的滑雪假期供应商之一:Sunweb Holidays
2018/01/06 全球购物
顶丰TOPPIK台湾官网:增发纤维假发,告别秃发困扰
2018/06/13 全球购物
枚举和一组预处理的#define有什么不同
2016/09/21 面试题
六一儿童节演讲稿
2014/05/23 职场文书
舞蹈专业求职信
2014/06/13 职场文书
应届大专生自荐书
2014/06/16 职场文书
乡镇镇长个人整改措施
2014/10/01 职场文书
领导欢迎词范文
2015/01/26 职场文书
大足石刻导游词
2015/02/02 职场文书
预备党员转正党小组意见
2015/06/01 职场文书
2016年党员干部公开承诺书
2016/03/24 职场文书
详解Go语言中Get/Post请求测试
2022/06/01 Golang