基于Tensorflow一维卷积用法详解


Posted in Python onMay 22, 2020

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
import numpy as np
input = tf.constant(1,shape=(64,10,1),dtype=tf.float32,name='input')#shape=(batch,in_width,in_channels)
w = tf.constant(3,shape=(3,1,32),dtype=tf.float32,name='w')#shape=(filter_width,in_channels,out_channels)
conv1 = tf.nn.conv1d(input,w,2,'VALID') #2为步长
print(conv1.shape)#宽度计算(width-kernel_size+1)/strides ,(10-3+1)/2=4 (64,4,32)
conv2 = tf.nn.conv1d(input,w,2,'SAME') #步长为2
print(conv2.shape)#宽度计算width/strides 10/2=5 (64,5,32)
conv3 = tf.nn.conv1d(input,w,1,'SAME') #步长为1
print(conv3.shape) # (64,10,32)
with tf.Session() as sess:
 print(sess.run(conv1))
 print(sess.run(conv2))
 print(sess.run(conv3))

基于Tensorflow一维卷积用法详解

以下是input_shape=(1,10,1), w = (3,1,1)时,conv1的shape

基于Tensorflow一维卷积用法详解

以下是input_shape=(1,10,1), w = (3,1,3)时,conv1的shape

基于Tensorflow一维卷积用法详解

补充知识:tensorflow中一维卷积conv1d处理语言序列举例

tf.nn.conv1d:

函数形式: tf.nn.conv1d(value, filters, stride, padding, use_cudnn_on_gpu=None, data_format=None, name=None):

程序举例:

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
 
# --------------- tf.nn.conv1d -------------------
inputs=tf.ones((64,10,3)) # [batch, n_sqs, embedsize]
w=tf.constant(1,tf.float32,(5,3,32)) # [w_high, embedsize, n_filers]
conv1 = tf.nn.conv1d(inputs,w,stride=2 ,padding='SAME') # conv1=[batch, round(n_sqs/stride), n_filers],stride是步长。
 
tf.global_variables_initializer().run()
out = sess.run(conv1)
print(out)

注:一维卷积中padding='SAME'只在输入的末尾填充0

tf.layters.conv1d:

函数形式:tf.layters.conv1d(inputs, filters, kernel_size, strides=1, padding='valid', data_format='channels_last', dilation_rate=1, activation=None, use_bias=True,...)

程序举例:

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
 
# --------------- tf.layters.conv1d -------------------
inputs=tf.ones((64,10,3)) # [batch, n_sqs, embedsize]
num_filters=32
kernel_size =5
conv2 = tf.layers.conv1d(inputs, num_filters, kernel_size,strides=2, padding='valid',name='conv2') # shape = (batchsize, round(n_sqs/strides),num_filters)
tf.global_variables_initializer().run()
out = sess.run(conv2)
print(out)

二维卷积实现一维卷积:

import tensorflow as tf
sess = tf.InteractiveSession()
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
def max_pool_1x2(x):
 return tf.nn.avg_pool(x, ksize=[1,1,2,1], strides=[1,1,2,1], padding='SAME')
'''
ksize = [x, pool_height, pool_width, x]
strides = [x, pool_height, pool_width, x]
'''
 
x = tf.Variable([[1,2,3,4]], dtype=tf.float32)
x = tf.reshape(x, [1,1,4,1]) #这一步必不可少,否则会报错说维度不一致;
'''
[batch, in_height, in_width, in_channels] = [1,1,4,1]
'''
 
W_conv1 = tf.Variable([1,1,1],dtype=tf.float32) # 权重值
W_conv1 = tf.reshape(W_conv1, [1,3,1,1]) # 这一步同样必不可少
'''
[filter_height, filter_width, in_channels, out_channels]
'''
h_conv1 = conv2d(x, W_conv1) # 结果:[4,8,12,11]
h_pool1 = max_pool_1x2(h_conv1)
tf.global_variables_initializer().run()
print(sess.run(h_conv1)) # 结果array([6,11.5])x

两种池化操作:

# 1:stride max pooling
convs = tf.expand_dims(conv, axis=-1) # shape=[?,596,256,1]
smp = tf.nn.max_pool(value=convs, ksize=[1, 3, self.config.num_filters, 1], strides=[1, 3, 1, 1],
     padding='SAME') # shape=[?,299,256,1]
smp = tf.squeeze(smp, -1) # shape=[?,299,256]
smp = tf.reshape(smp, shape=(-1, 199 * self.config.num_filters))
 
# 2: global max pooling layer
gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')

不同核尺寸卷积操作:

kernel_sizes = [3,4,5] # 分别用窗口大小为3/4/5的卷积核
with tf.name_scope("mul_cnn"):
 pooled_outputs = []
 for kernel_size in kernel_sizes:
  # CNN layer
  conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, kernel_size, name='conv-%s' % kernel_size)
  # global max pooling layer
  gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')
  pooled_outputs.append(gmp)
 self.h_pool = tf.concat(pooled_outputs, 1) #池化后进行拼接

以上这篇基于Tensorflow一维卷积用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现抓取城市的PM2.5浓度和排名
Mar 19 Python
python实现对一个完整url进行分割的方法
Apr 29 Python
python中list常用操作实例详解
Jun 03 Python
深入理解Python中的super()方法
Nov 20 Python
python用户评论标签匹配的解决方法
May 31 Python
django开发post接口简单案例,获取参数值的方法
Dec 11 Python
python3实现钉钉消息推送的方法示例
Mar 14 Python
python3 json数据格式的转换(dumps/loads的使用、dict to str/str to dict、json字符串/字典的相互转换)
Apr 01 Python
手把手教你Python yLab的绘制折线图的画法
Oct 23 Python
Python3.7将普通图片(png)转换为SVG图片格式(网站logo图标)动起来
Apr 21 Python
python必学知识之文件操作(建议收藏)
May 30 Python
Python爬取用户观影数据并分析用户与电影之间的隐藏信息!
Jun 29 Python
Python参数传递机制传值和传引用原理详解
May 22 #Python
python filecmp.dircmp实现递归比对两个目录的方法
May 22 #Python
关于keras.layers.Conv1D的kernel_size参数使用介绍
May 22 #Python
Python参数传递对象的引用原理解析
May 22 #Python
Python configparser模块常用方法解析
May 22 #Python
keras中的卷积层&池化层的用法
May 22 #Python
Keras Convolution1D与Convolution2D区别说明
May 22 #Python
You might like
php MySQL与分页效率
2008/06/04 PHP
php中jQuery插件autocomplate的简单使用笔记
2012/06/14 PHP
shell脚本作为保证PHP脚本不挂掉的守护进程实例分享
2013/07/15 PHP
PHP面向对象程序设计实例分析
2016/01/26 PHP
php+redis实现多台服务器内网存储session并读取示例
2017/01/12 PHP
PHP实现的折半查询算法示例
2017/10/09 PHP
js wmp操作代码小结(音乐连播功能)
2008/11/08 Javascript
JS 实现双色表格实现代码
2009/11/24 Javascript
JavaScript Object的extend是一个常用的功能
2009/12/02 Javascript
js 实现图片预加载(js操作 Image对象属性complete ,事件onload 异步加载图片)
2011/03/25 Javascript
js 获取屏幕各种宽高的方法(浏览器兼容)
2013/05/15 Javascript
node.js中的fs.existsSync方法使用说明
2014/12/17 Javascript
jquery 无限极下拉菜单的简单实例(精简浓缩版)
2016/05/31 Javascript
canvas实现图片根据滑块放大缩小效果
2017/02/24 Javascript
nodejs取得当前执行路径的方法
2018/05/13 NodeJs
在Bootstrap开发框架中使用dataTable直接录入表格行数据的方法
2018/10/25 Javascript
Layui 数据表格批量删除和多条件搜索的实例
2019/09/04 Javascript
使用 js 简单的实现 bind、call 、aplly代码实例
2019/09/07 Javascript
微信小程序事件流原理解析
2019/11/27 Javascript
javascript实现切割轮播效果
2019/11/28 Javascript
Python 中 list 的各项操作技巧
2017/04/13 Python
Django视图和URL配置详解
2018/01/31 Python
使用Keras预训练好的模型进行目标类别预测详解
2020/06/27 Python
pandas处理csv文件的方法步骤
2020/10/16 Python
Notino芬兰:购买香水和化妆品
2019/04/15 全球购物
意大利一家专营包包和配饰的网上商店:Borse Last Minute
2019/08/26 全球购物
如何实现jdbc性能优化
2012/07/30 面试题
非功能性需求都包括哪些方面
2013/10/29 面试题
七一表彰活动方案
2014/01/18 职场文书
五型班组建设方案
2014/02/10 职场文书
四风对照检查材料范文
2014/09/27 职场文书
紫日观后感
2015/06/05 职场文书
公司食堂管理制度
2015/08/05 职场文书
win10下go mod配置方式
2021/04/25 Golang
聊聊Lombok中的@Builder注解使用教程
2021/11/17 Java/Android
Python加密与解密模块hashlib与hmac
2022/06/05 Python