PyTorch 中的傅里叶卷积实现示例


Posted in Python onDecember 11, 2020

卷积

卷积在数据分析中无处不在。几十年来,它们一直被用于信号和图像处理。最近,它们成为现代神经网络的重要组成部分。如果你处理数据的话,你可能会遇到错综复杂的问题。

数学上,卷积表示为:

PyTorch 中的傅里叶卷积实现示例

尽管离散卷积在计算应用程序中更为常见,但在本文的大部分内容中我将使用连续形式,因为使用连续变量来证明卷积定理(下面讨论)要容易得多。之后,我们将回到离散情况,并使用傅立叶变换在 PyTorch 中实现它。离散卷积可以看作是连续卷积的近似,其中连续函数离散在规则网格上。因此,我们不会为这个离散的案例重新证明卷积定理。

卷积定理

从数学上来说,卷积定理可以这样描述:

PyTorch 中的傅里叶卷积实现示例

其中的连续傅里叶变换是(达到正常化常数) :

PyTorch 中的傅里叶卷积实现示例

换句话说,位置空间中的卷积等价于频率空间中的直乘。这个想法是相当不直观的,但是对于连续的情况来说,证明卷积定理是惊人的容易。要做到这一点,首先要写出等式的左边。

PyTorch 中的傅里叶卷积实现示例

现在切换积分的顺序,替换变量(x = y + z) ,并分离两个被积函数。

PyTorch 中的傅里叶卷积实现示例

我们为什么要关心这一切?

因为快速傅里叶变换的算法复杂度低于卷积。直接卷积运算具有复杂度 O(n^2) ,因为在 f 中,我们传递 g 中的每个元素,所以可以在 O(nlogn)时间内计算出快速傅立叶变换。当输入数组很大时,它们比卷积要快得多。在这些情况下,我们可以使用卷积定理计算频率空间中的卷积,然后执行逆傅里叶变换回到位置空间。

当输入较小时(例如3x3卷积内核) ,直接卷积仍然更快。在机器学习应用程序中,使用小内核更为常见,因此像 PyTorch 和 Tensorflow 这样的深度学习库只提供直接卷积的实现。但是在现实世界中有很多使用大内核的用例,其中傅立叶卷积算法更有效。

PyTorch 实现

现在,我将演示如何在 PyTorch 中实现傅里叶卷积函数。它应该模仿 torch.nn.functional.convNd 的功能,并利用 fft,而不需要用户做任何额外的工作。因此,它应该接受三个 Tensors (signal、kernel 和可选 bias)和应用于输入的 padding。从概念上讲,这个函数的内部工作原理是:

def fft_conv(
  signal: Tensor, kernel: Tensor, bias: Tensor = None, padding: int = 0,
) -> Tensor:
  # 1. Pad the input signal & kernel tensors
  # 2. Compute FFT for both signal & kernel
  # 3. Multiply the transformed Tensors together
  # 4. Compute inverse FFT
  # 5. Add bias and return

让我们按照上面显示的操作顺序逐步构建 FFT 卷积。对于这个例子,我将构建一个一维傅里叶卷积,但是将其扩展到二维和三维卷积是很简单的。

1. 填充输入数组

我们需要确保 signal 和 kernel 在填充之后有相同的大小。应用初始填充 signal,然后调整 kernel 的填充以匹配。

# 1. Pad the input signal & kernel tensors
signal = f.pad(signal, [padding, padding])
kernel_padding = [0, signal.size(-1) - kernel.size(-1)]
padded_kernel = f.pad(kernel, kernel_padding)

注意,我只在一边填充 kernel。我们希望原始内核位于填充数组的左侧,这样它就可以与 signal 数组的开始对齐。

2. 计算傅立叶变换

这非常简单,因为 n 维 fft 已经在 PyTorch 中实现了。我们简单地使用内置函数,并计算沿每个张量的最后一个维数的 FFT。

# 2. Perform fourier convolution
signal_fr = rfftn(signal, dim=-1)
kernel_fr = rfftn(padded_kernel, dim=-1)

3. 变换张量相乘

令人惊讶的是,这是我们功能中最复杂的部分。这有两个原因。(1) PyTorch 卷积运行于多维张量上,因此我们的 signal 和 kernel 张量实际上是三维的。从 PyTorch 文档中的这个方程式,我们可以看到矩阵乘法是在前两个维度上运行的(不包括偏差项) :

PyTorch 中的傅里叶卷积实现示例

我们将需要包括这个矩阵乘法,以及对转换后的维度的直接乘法。

PyTorch 实际上实现了互相关/值方法而不是卷积方法。(TensorFlow 和其他深度学习库也是如此。)互相关与卷积密切相关,但有一个重要的标志变化:

PyTorch 中的傅里叶卷积实现示例

与卷积相比,这有效地逆转了核的方向(g)。我们不是手动翻转内核,而是在傅里叶空间中利用内核的共轭复数来纠正这个问题。由于我们不需要创建一个全新的 Tensor,所以这样做的速度明显更快,内存效率也更高。(本文末尾的附录中简要说明了这种方法的工作原理。)

# 3. Multiply the transformed matrices
 
def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
  """Multiplies two complex-valued tensors."""
  # Scalar matrix multiplication of two tensors, over only the first two dimensions.
  # Dimensions 3 and higher will have the same shape after multiplication.
  scalar_matmul = partial(torch.einsum, "ab..., cb... -> ac...") 
 
  # Compute the real and imaginary parts independently, then manually insert them
  # into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,
  # because Autograd is not enabled for complex matrix operations yet. Not exactly
  # idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).
  real = scalar_matmul(a.real, b.real) - scalar_matmul(a.imag, b.imag)
  imag = scalar_matmul(a.imag, b.real) + scalar_matmul(a.real, b.imag)
  c = torch.zeros(real.shape, dtype=torch.complex64)
  c.real, c.imag = real, imag
  return c 

# Conjugate the kernel for cross-correlation
kernel_fr.imag *= -1
output_fr = complex_matmul(signal_fr, kernel_fr)

PyTorch 1.7改进了对复数的支持,但是在 autograd 中还不支持对复数张量的许多操作。现在,我们必须编写我们自己的复杂 matmul 方法作为一个补丁。虽然不是很理想,但是它确实有效,并且在未来的版本中不会出现问题。

4. 计算逆变换

使用 torch.irfftn 可以直接计算逆变换,然后裁剪出额外的数组填充。

# 4. Compute inverse FFT, and remove extra padded values
output = irfftn(output_fr, dim=-1)
output = output[:, :, :signal.size(-1) - kernel.size(-1) + 1]

5. 添加偏执项并返回

添加偏差项也很容易。请记住,对于输出阵列中的每个通道,偏置项都有一个元素,并相应地调整其形状。

# 5. Optionally, add a bias term before returning.
if bias is not None:
  output += bias.view(1, -1, 1)

将上述代码整合在一起

为了完整起见,让我们将所有这些代码片段编译成一个内聚函数。

def fft_conv_1d(
  signal: Tensor, kernel: Tensor, bias: Tensor = None, padding: int = 0,
) -> Tensor:
  """
  Args:
    signal: (Tensor) Input tensor to be convolved with the kernel.
    kernel: (Tensor) Convolution kernel.
    bias: (Optional, Tensor) Bias tensor to add to the output.
    padding: (int) Number of zero samples to pad the input on the last dimension.
  Returns:
    (Tensor) Convolved tensor
  """
  # 1. Pad the input signal & kernel tensors
  signal = f.pad(signal, [padding, padding])
  kernel_padding = [0, signal.size(-1) - kernel.size(-1)]
  padded_kernel = f.pad(kernel, kernel_padding)
 
  # 2. Perform fourier convolution
  signal_fr = rfftn(signal, dim=-1)
  kernel_fr = rfftn(padded_kernel, dim=-1)
 
  # 3. Multiply the transformed matrices
  kernel_fr.imag *= -1
  output_fr = complex_matmul(signal_fr, kernel_fr)
 
  # 4. Compute inverse FFT, and remove extra padded values
  output = irfftn(output_fr, dim=-1)
  output = output[:, :, :signal.size(-1) - kernel.size(-1) + 1]
 
  # 5. Optionally, add a bias term before returning.
  if bias is not None:
    output += bias.view(1, -1, 1)
 
 
  return output

直接卷积测试

最后,我们将使用 torch.nn.functional.conv1d 来确认这在数值上等同于直接一维卷积。我们为所有输入构造随机张量,并测量输出值的相对差异。

import torch
import torch.nn.functional as f 
 
torch.manual_seed(1234)
kernel = torch.randn(2, 3, 1025)
signal = torch.randn(3, 3, 4096)
bias = torch.randn(2)
 
y0 = f.conv1d(signal, kernel, bias=bias, padding=512)
y1 = fft_conv_1d(signal, kernel, bias=bias, padding=512)
 
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')
 
# Abs Error Mean: 1.272E-05

考虑到我们使用的是32位精度,每个元素相差大约1e-5?相当精确!让我们也执行一个快速的基准来测量每个方法的速度:

from timeit import timeit
direct_time = timeit(
  "f.conv1d(signal, kernel, bias=bias, padding=512)", 
  globals=locals(), 
  number=100
) / 100
fourier_time = timeit(
  "fft_conv_1d(signal, kernel, bias=bias, padding=512)", 
  globals=locals(), 
  number=100
) / 100
print(f"Direct time: {direct_time:.3E} s")
print(f"Fourier time: {fourier_time:.3E} s")
 
# Direct time: 1.523E-02 s
# Fourier time: 1.149E-03 s

测量的基准将随着您使用的机器而发生显著的变化。(我正在用一台非常旧的 Macbook Pro 进行测试。)对于1025的内核,傅里叶卷积似乎要快10倍以上。

总结

我希望这已经提供了一个彻底的介绍傅里叶卷积。我认为这是一个非常酷的技巧,在现实世界中有很多应用程序可以使用它。我也喜欢数学,所以看到编程和纯数学的结合是很有趣的。欢迎和鼓励所有的评论和建设性的批评,如果你喜欢这篇文章,请鼓掌!

附录:

卷积 vs. 互相关

在本文的前面,我们通过在傅里叶空间中取得内核的互相关共轭复数来实现。这实际上颠倒了 kernel 的方向,现在我想演示一下为什么会这样。首先,记住卷积和互相关的公式:

PyTorch 中的傅里叶卷积实现示例

然后,让我们来看看 g(x) 的傅里叶变换:

PyTorch 中的傅里叶卷积实现示例

注意,g(x)是实值的,所以它不受共轭复数变化的影响。然后,更改变量(y =-x)并简化表达式。

PyTorch 中的傅里叶卷积实现示例

到此这篇关于PyTorch 中的傅里叶卷积实现示例的文章就介绍到这了,更多相关PyTorch 傅里叶卷积内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python构造icmp echo请求和实现网络探测器功能代码分享
Jan 10 Python
Python简单遍历字典及删除元素的方法
Sep 18 Python
Python中表示字符串的三种方法
Sep 06 Python
python如何实现int函数的方法示例
Feb 19 Python
python实现手机通讯录搜索功能
Feb 22 Python
django 将model转换为字典的方法示例
Oct 16 Python
Python 使用Numpy对矩阵进行转置的方法
Jan 28 Python
利用Python对文件夹下图片数据进行批量改名的代码实例
Feb 21 Python
解决tensorflow训练时内存持续增加并占满的问题
Jan 19 Python
Python通过2种方法输出带颜色字体
Mar 02 Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 Python
使用pytorch实现论文中的unet网络
Jun 24 Python
python中append函数用法讲解
Dec 11 #Python
python实现图像随机裁剪的示例代码
Dec 10 #Python
python opencv图像处理(素描、怀旧、光照、流年、滤镜 原理及实现)
Dec 10 #Python
python 实现的IP 存活扫描脚本
Dec 10 #Python
class类在python中获取金融数据的实例方法
Dec 10 #Python
Python制作简单的剪刀石头布游戏
Dec 10 #Python
python给list排序的简单方法
Dec 10 #Python
You might like
浅谈PHP解析URL函数parse_url和parse_str
2014/11/11 PHP
php通过记录IP来防止表单重复提交方法分析
2014/12/16 PHP
CodeIgniter表单验证方法实例详解
2016/03/03 PHP
Laravel如何使用Redis共享Session
2018/02/23 PHP
TNC vs IO BO3 第一场2.13
2021/03/10 DOTA
JQuery 文本框使用小结
2010/05/22 Javascript
IE图片缓存document.execCommand("BackgroundImageCache",false,true)
2011/03/01 Javascript
找出字符串中出现次数最多的字母和出现次数精简版
2012/11/07 Javascript
javascript框架设计之种子模块
2015/06/23 Javascript
JS原型链怎么理解
2016/06/27 Javascript
JavaScript制作简易计算器(不用eval)
2017/02/05 Javascript
vue开发拖拽进度条滑动组件
2019/09/21 Javascript
JavaScript实现单图片上传并预览功能
2019/09/30 Javascript
jQuery HTML获取内容和属性操作实例分析
2020/05/20 jQuery
js实现炫酷光感效果
2020/09/05 Javascript
Python爬虫框架Scrapy安装使用步骤
2014/04/01 Python
python利用beautifulSoup实现爬虫
2014/09/29 Python
Python实现控制台进度条功能
2016/01/04 Python
Python实现简单文本字符串处理的方法
2018/01/22 Python
Window 64位下python3.6.2环境搭建图文教程
2018/09/19 Python
django框架事务处理小结【ORM 事务及raw sql,customize sql 事务处理】
2019/06/27 Python
python实现ip代理池功能示例
2019/07/05 Python
python处理RSTP视频流过程解析
2020/01/11 Python
Python通过2种方法输出带颜色字体
2020/03/02 Python
python3实现语音转文字(语音识别)和文字转语音(语音合成)
2020/10/14 Python
html5中的一些标签学习(心得)
2016/10/18 HTML / CSS
Ralph Lauren意大利官方网站:时尚界最负盛名的品牌之一
2018/10/18 全球购物
Timberland俄罗斯官方网上商店:全球领先的户外品牌
2020/03/15 全球购物
工商管理专业学生的自我评价
2013/10/01 职场文书
大学新生军训个人的自我评价
2013/10/03 职场文书
机械化及自动化毕业生的自我评价分享
2013/11/06 职场文书
战略合作协议书范本
2014/04/18 职场文书
公司市场部岗位职责
2015/04/15 职场文书
2016年学校党支部公开承诺书
2016/03/25 职场文书
毕业生自荐求职信书写的技巧
2019/08/26 职场文书
关于Python OS模块常用文件/目录函数详解
2021/07/01 Python