pytorch 自定义卷积核进行卷积操作方式


Posted in Python onDecember 30, 2019

一 卷积操作:在pytorch搭建起网络时,大家通常都使用已有的框架进行训练,在网络中使用最多就是卷积操作,最熟悉不过的就是

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

通过上面的输入发现想自定义自己的卷积核,比如高斯核,发现是行不通的,因为上面的参数里面只有卷积核尺寸,而权值weight是通过梯度一直更新的,是不确定的。

二 需要自己定义卷积核的目的:目前是需要通过一个VGG网络提取特征特后需要对其进行高斯卷积,卷积后再继续输入到网络中训练。

三 解决方案。使用

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

pytorch 自定义卷积核进行卷积操作方式

这里注意下weight的参数。与nn.Conv2d的参数不一样

可以发现F.conv2d可以直接输入卷积的权值weight,也就是卷积核。那么接下来就要首先生成一个高斯权重了。这里不直接一步步写了,直接输入就行。

kernel = [[0.03797616, 0.044863533, 0.03797616],
     [0.044863533, 0.053, 0.044863533],
     [0.03797616, 0.044863533, 0.03797616]]

四 完整代码

class GaussianBlur(nn.Module):
  def __init__(self):
    super(GaussianBlur, self).__init__()
    kernel = [[0.03797616, 0.044863533, 0.03797616],
         [0.044863533, 0.053, 0.044863533],
         [0.03797616, 0.044863533, 0.03797616]]
    kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
    self.weight = nn.Parameter(data=kernel, requires_grad=False)
 
  def forward(self, x):
    x1 = x[:, 0]
    x2 = x[:, 1]
    x3 = x[:, 2]
    x1 = F.conv2d(x1.unsqueeze(1), self.weight, padding=2)
    x2 = F.conv2d(x2.unsqueeze(1), self.weight, padding=2)
    x3 = F.conv2d(x3.unsqueeze(1), self.weight, padding=2)
    x = torch.cat([x1, x2, x3], dim=1)
    return x

这里为了网络模型需要写成了一个类,这里假设输入的x也就是经过网络提取后的三通道特征图(当然不一定是三通道可以是任意通道)

如果是任意通道的话,使用torch.expand()向输入的维度前面进行扩充。如下:

def blur(self, tensor_image):
    kernel = [[0.03797616, 0.044863533, 0.03797616],
        [0.044863533, 0.053, 0.044863533],
        [0.03797616, 0.044863533, 0.03797616]]
    
    min_batch=tensor_image.size()[0]
    channels=tensor_image.size()[1]
    out_channel=channels
    kernel = torch.FloatTensor(kernel).expand(out_channel,channels,3,3)
    self.weight = nn.Parameter(data=kernel, requires_grad=False)
 
    return F.conv2d(tensor_image,self.weight,1,1)

以上这篇pytorch 自定义卷积核进行卷积操作方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python开发之文件操作用法实例
Nov 13 Python
浅析Python 中整型对象存储的位置
May 16 Python
Python之web模板应用
Dec 26 Python
python自动重试第三方包retrying模块的方法
Apr 24 Python
python 按照固定长度分割字符串的方法小结
Apr 30 Python
python Web开发你要理解的WSGI & uwsgi详解
Aug 01 Python
python tkinter基本属性详解
Sep 16 Python
Python列表切片常用操作实例解析
Dec 16 Python
python不同版本的_new_不同点总结
Dec 09 Python
Python+Opencv实现把图片、视频互转的示例
Dec 17 Python
使用Python爬取小姐姐图片(beautifulsoup法)
Feb 11 Python
移除Selenium中window.navigator.webdriver值
Jun 10 Python
PyTorch中反卷积的用法详解
Dec 30 #Python
python使用正则表达式(Regular Expression)方法超详细
Dec 30 #Python
Pytorch实现各种2d卷积示例
Dec 30 #Python
Python面向对象之多态原理与用法案例分析
Dec 30 #Python
Pytoch之torchvision.transforms图像变换实例
Dec 30 #Python
python面向对象之类属性和类方法案例分析
Dec 30 #Python
基于Python执行dos命令并获取输出的结果
Dec 30 #Python
You might like
php数组函数序列之array_search()- 按元素值返回键名
2011/11/04 PHP
php开发文档 会员收费1期
2012/08/14 PHP
php自动加载机制的深入分析
2013/06/08 PHP
PHP文件操作方法汇总
2015/07/01 PHP
PHP实现登录注册之BootStrap表单功能
2017/09/03 PHP
PHP连接及操作PostgreSQL数据库的方法详解
2019/01/30 PHP
用js计算页面执行时间的函数
2006/12/07 Javascript
Jquery下:nth-child(an+b)的使用注意
2011/05/28 Javascript
jQuery中获取Radio元素值的方法
2013/07/02 Javascript
js中的异常处理try...catch使用介绍
2013/09/21 Javascript
jQuery通过点击行来删除HTML表格行的实现示例
2014/09/10 Javascript
JS合并数组的几种方法及优劣比较
2014/09/19 Javascript
jQuery+JSON实现AJAX二级联动实例分析
2015/12/18 Javascript
AngularJS基础 ng-mouseover 指令简单示例
2016/08/02 Javascript
jquery根据name取得select选中的值实例(超简单)
2018/01/25 jQuery
jquery的 filter()方法使用教程
2018/03/22 jQuery
vue权限路由实现的方法示例总结
2018/07/29 Javascript
深度解读vue-resize的具体用法
2020/07/08 Javascript
Python爬取APP下载链接的实现方法
2016/09/30 Python
利用Python开发实现简单的记事本
2016/11/15 Python
SELENIUM自动化模拟键盘快捷键操作实现解析
2019/10/28 Python
Python列表原理与用法详解【创建、元素增加、删除、访问、计数、切片、遍历等】
2019/10/30 Python
Keras在训练期间可视化训练误差和测试误差实例
2020/06/16 Python
解决python3.x安装numpy成功但import出错的问题
2020/11/17 Python
Mavi牛仔裤美国官网:土耳其著名牛仔品牌
2016/09/24 全球购物
微软巴西官方网站:Microsoft Brasil
2019/09/26 全球购物
假日旅行社实习自我鉴定
2013/09/24 职场文书
机电一体化大学生求职信
2013/11/08 职场文书
企业宗旨标语
2014/06/10 职场文书
事业单位工作人员年度考核个人总结
2015/02/12 职场文书
2016年度员工工作表现评语
2015/12/02 职场文书
给原生html中添加水印遮罩层的实现示例
2021/04/02 Javascript
python实现三次密码验证的示例
2021/04/29 Python
详解Python函数print用法
2021/06/18 Python
浅谈Python数学建模之数据导入
2021/06/23 Python
JavaScript架构localStorage特殊场景下二次封装操作
2022/06/21 Javascript