基于Tensorflow一维卷积用法详解


Posted in Python onMay 22, 2020

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
import numpy as np
input = tf.constant(1,shape=(64,10,1),dtype=tf.float32,name='input')#shape=(batch,in_width,in_channels)
w = tf.constant(3,shape=(3,1,32),dtype=tf.float32,name='w')#shape=(filter_width,in_channels,out_channels)
conv1 = tf.nn.conv1d(input,w,2,'VALID') #2为步长
print(conv1.shape)#宽度计算(width-kernel_size+1)/strides ,(10-3+1)/2=4 (64,4,32)
conv2 = tf.nn.conv1d(input,w,2,'SAME') #步长为2
print(conv2.shape)#宽度计算width/strides 10/2=5 (64,5,32)
conv3 = tf.nn.conv1d(input,w,1,'SAME') #步长为1
print(conv3.shape) # (64,10,32)
with tf.Session() as sess:
 print(sess.run(conv1))
 print(sess.run(conv2))
 print(sess.run(conv3))

基于Tensorflow一维卷积用法详解

以下是input_shape=(1,10,1), w = (3,1,1)时,conv1的shape

基于Tensorflow一维卷积用法详解

以下是input_shape=(1,10,1), w = (3,1,3)时,conv1的shape

基于Tensorflow一维卷积用法详解

补充知识:tensorflow中一维卷积conv1d处理语言序列举例

tf.nn.conv1d:

函数形式: tf.nn.conv1d(value, filters, stride, padding, use_cudnn_on_gpu=None, data_format=None, name=None):

程序举例:

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
 
# --------------- tf.nn.conv1d -------------------
inputs=tf.ones((64,10,3)) # [batch, n_sqs, embedsize]
w=tf.constant(1,tf.float32,(5,3,32)) # [w_high, embedsize, n_filers]
conv1 = tf.nn.conv1d(inputs,w,stride=2 ,padding='SAME') # conv1=[batch, round(n_sqs/stride), n_filers],stride是步长。
 
tf.global_variables_initializer().run()
out = sess.run(conv1)
print(out)

注:一维卷积中padding='SAME'只在输入的末尾填充0

tf.layters.conv1d:

函数形式:tf.layters.conv1d(inputs, filters, kernel_size, strides=1, padding='valid', data_format='channels_last', dilation_rate=1, activation=None, use_bias=True,...)

程序举例:

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
 
# --------------- tf.layters.conv1d -------------------
inputs=tf.ones((64,10,3)) # [batch, n_sqs, embedsize]
num_filters=32
kernel_size =5
conv2 = tf.layers.conv1d(inputs, num_filters, kernel_size,strides=2, padding='valid',name='conv2') # shape = (batchsize, round(n_sqs/strides),num_filters)
tf.global_variables_initializer().run()
out = sess.run(conv2)
print(out)

二维卷积实现一维卷积:

import tensorflow as tf
sess = tf.InteractiveSession()
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
def max_pool_1x2(x):
 return tf.nn.avg_pool(x, ksize=[1,1,2,1], strides=[1,1,2,1], padding='SAME')
'''
ksize = [x, pool_height, pool_width, x]
strides = [x, pool_height, pool_width, x]
'''
 
x = tf.Variable([[1,2,3,4]], dtype=tf.float32)
x = tf.reshape(x, [1,1,4,1]) #这一步必不可少,否则会报错说维度不一致;
'''
[batch, in_height, in_width, in_channels] = [1,1,4,1]
'''
 
W_conv1 = tf.Variable([1,1,1],dtype=tf.float32) # 权重值
W_conv1 = tf.reshape(W_conv1, [1,3,1,1]) # 这一步同样必不可少
'''
[filter_height, filter_width, in_channels, out_channels]
'''
h_conv1 = conv2d(x, W_conv1) # 结果:[4,8,12,11]
h_pool1 = max_pool_1x2(h_conv1)
tf.global_variables_initializer().run()
print(sess.run(h_conv1)) # 结果array([6,11.5])x

两种池化操作:

# 1:stride max pooling
convs = tf.expand_dims(conv, axis=-1) # shape=[?,596,256,1]
smp = tf.nn.max_pool(value=convs, ksize=[1, 3, self.config.num_filters, 1], strides=[1, 3, 1, 1],
     padding='SAME') # shape=[?,299,256,1]
smp = tf.squeeze(smp, -1) # shape=[?,299,256]
smp = tf.reshape(smp, shape=(-1, 199 * self.config.num_filters))
 
# 2: global max pooling layer
gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')

不同核尺寸卷积操作:

kernel_sizes = [3,4,5] # 分别用窗口大小为3/4/5的卷积核
with tf.name_scope("mul_cnn"):
 pooled_outputs = []
 for kernel_size in kernel_sizes:
  # CNN layer
  conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, kernel_size, name='conv-%s' % kernel_size)
  # global max pooling layer
  gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')
  pooled_outputs.append(gmp)
 self.h_pool = tf.concat(pooled_outputs, 1) #池化后进行拼接

以上这篇基于Tensorflow一维卷积用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中django框架通过正则搜索页面上email地址的方法
Mar 21 Python
Python自定义scrapy中间模块避免重复采集的方法
Apr 07 Python
手把手教你用python抢票回家过年(代码简单)
Jan 21 Python
python中将一个全部为int的list 转化为str的list方法
Apr 09 Python
python绘制漏斗图步骤详解
Mar 04 Python
Python批量查询关键词微信指数实例方法
Jun 27 Python
python把ipynb文件转换成pdf文件过程详解
Jul 09 Python
基于django ManyToMany 使用的注意事项详解
Aug 09 Python
解决python中的幂函数、指数函数问题
Nov 25 Python
python中使用paramiko模块并实现远程连接服务器执行上传下载功能
Feb 29 Python
python异步Web框架sanic的实现
Apr 27 Python
使用Python+Appuim 清理微信的方法
Jan 26 Python
Python参数传递机制传值和传引用原理详解
May 22 #Python
python filecmp.dircmp实现递归比对两个目录的方法
May 22 #Python
关于keras.layers.Conv1D的kernel_size参数使用介绍
May 22 #Python
Python参数传递对象的引用原理解析
May 22 #Python
Python configparser模块常用方法解析
May 22 #Python
keras中的卷积层&池化层的用法
May 22 #Python
Keras Convolution1D与Convolution2D区别说明
May 22 #Python
You might like
php中unserialize返回false的解决方法
2014/09/22 PHP
php开发工具有哪五款
2015/11/09 PHP
使用Yii2实现主从数据库设置
2016/11/20 PHP
PDO::rollBack讲解
2019/01/29 PHP
php多进程应用场景实例详解
2019/07/22 PHP
让任务管理器中的CPU跳舞的js代码
2008/11/01 Javascript
JavaScript mapreduce工作原理简析
2012/11/25 Javascript
js实现一个省市区三级联动选择框代码分享
2013/03/06 Javascript
关于IE BUG与字符串截取substr的解决办法
2013/04/10 Javascript
input链接页面、打开新网页等等的具体实现
2013/12/30 Javascript
Javascript的严格模式strict mode详细介绍
2014/06/06 Javascript
Javascript实现网络监测的方法
2015/07/31 Javascript
jquery插件之文字间歇自动向上滚动效果代码
2016/02/25 Javascript
详解angularJs指令的3种绑定策略
2017/04/13 Javascript
vue+node 实现视频在线播放的实例代码
2020/10/19 Javascript
python使用MySQLdb访问mysql数据库的方法
2015/08/03 Python
python将ansible配置转为json格式实例代码
2017/05/15 Python
Python书单 不将就
2017/07/11 Python
python模块之sys模块和序列化模块(实例讲解)
2017/09/13 Python
很酷的python表白工具 你喜欢我吗
2019/04/11 Python
浅析Python 中几种字符串格式化方法及其比较
2019/07/02 Python
Python3 tkinter 实现文件读取及保存功能
2019/09/12 Python
python 链接sqlserver 写接口实例
2020/03/11 Python
python中shell执行知识点
2020/05/06 Python
Python-for循环的内部机制
2020/06/12 Python
Melissa香港官网:MDreams
2016/07/01 全球购物
董事长秘书岗位职责
2013/11/29 职场文书
程序员求职信
2014/04/16 职场文书
小学优秀辅导员事迹材料
2014/05/11 职场文书
幼儿园感恩节活动方案
2014/10/06 职场文书
亲属关系公证书样本
2015/01/23 职场文书
保研推荐信范文
2015/03/25 职场文书
2015年学校综合治理工作总结
2015/07/20 职场文书
详解RedisTemplate下Redis分布式锁引发的系列问题
2021/04/27 Redis
如何利用React实现图片识别App
2022/02/18 Javascript
python实现学生信息管理系统(面向对象)
2022/06/05 Python