Pytorch 抽取vgg各层并进行定制化处理的方法


Posted in Python onAugust 20, 2019

工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式

def define_vgg(vgg,input_channels,endlayer,use_maxpool=False): 
  vgg_ad = copy.deepcopy(vgg)
  model = nn.Sequential()
  i = 0
  for layer in list(vgg_ad.features):
    if i > endlayer:
      break
    if isinstance(layer, nn.Conv2d) and i is 0:
      name = "conv_" + str(i)
      layer = nn.Conv2d(input_channels,
               layer.out_channels,
               layer.kernel_size,
               stride = layer.stride,
               padding=layer.padding)
      model.add_module(name, layer)
    if isinstance(layer, nn.Conv2d):
      name = "conv_" + str(i)
      model.add_module(name, layer)
 
    if isinstance(layer, nn.ReLU):
      name = "leakyrelu_" + str(i)
      layer = nn.LeakyReLU(inplace=True) 
      model.add_module(name, layer)
 
    if isinstance(layer, nn.MaxPool2d):
      name = "pool_" + str(i)
      if use_maxpool:
        model.add_module(name, layer)
      else:
        avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding)
        model.add_module(name, avgpool)
    i += 1
  return model

函数输入项中的vgg 是直接使用的import torchvision.models.vgg16 传入的是vgg16 非预训练版本。end_layer 是需要提取的层数,这里使用了vgg.features 是指仅仅在vgg.features 上进行层的提取;也可以根据定制在classifier上进行提取。

下面是我的一个提取前7层的示例,可以使用pyCharm evaluate 上面函数返回的model,可以看到这个示例的情况,这里我的定制条件是输入通道为2 ,需要提取前7层,并且将ReLu更换为LeakyRelu。

Sequential(
 (conv_0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_1): LeakyReLU(negative_slope=0.01, inplace)
 (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_3): LeakyReLU(negative_slope=0.01, inplace)
 (pool_4): AvgPool2d(kernel_size=2, stride=2, padding=0)
 (conv_5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_6): LeakyReLU(negative_slope=0.01, inplace)
 (conv_7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

以上这篇Pytorch 抽取vgg各层并进行定制化处理的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python单元测试框架unittest使用方法讲解
Apr 13 Python
在Python中用get()方法获取字典键值的教程
May 21 Python
遍历python字典几种方法总结(推荐)
Sep 11 Python
分析python切片原理和方法
Dec 19 Python
Python通过调用mysql存储过程实现更新数据功能示例
Apr 03 Python
解决新版Pycharm中Matplotlib图像不在弹出独立的显示窗口问题
Jan 15 Python
Python 移动光标位置的方法
Jan 20 Python
Django的Modelforms用法简介
Jul 27 Python
如何获取Python简单for循环索引
Nov 21 Python
Python基于httpx模块实现发送请求
Jul 07 Python
python爬虫看看虎牙女主播中谁最“顶”步骤详解
Dec 01 Python
Python 机器学习工具包SKlearn的安装与使用
May 14 Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 #Python
pytorch在fintune时将sequential中的层输出方法,以vgg为例
Aug 20 #Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
You might like
Php+SqlServer实现分页显示
2006/10/09 PHP
PHP6 先修班 JSON实例代码
2008/08/23 PHP
Joomla开启SEF的方法
2016/05/04 PHP
学习ExtJS Panel常用方法
2009/10/07 Javascript
firefox和IE系列的相关区别整理 以备后用
2009/12/28 Javascript
Jquery实现视频播放页面的关灯开灯效果
2013/05/27 Javascript
浅谈JSON中stringify 函数、toJosn函数和parse函数
2015/01/26 Javascript
BOOTSTRAP时间控件显示在模态框下面的bug修复
2015/02/05 Javascript
JavaScript中用toString()方法返回时间为字符串
2015/06/12 Javascript
在for循环中length值是否需要缓存
2015/07/27 Javascript
jQuery实现带滑动条的菜单效果代码
2015/08/26 Javascript
javaScript数组迭代方法详解
2016/04/14 Javascript
jQuery实现区域打印功能代码详解
2016/06/17 Javascript
JavaScript读二进制文件并用ajax传输二进制流的方法
2016/07/18 Javascript
JavaScript优化以及前段开发小技巧
2017/02/02 Javascript
Vue 2.X的状态管理vuex记录详解
2017/03/23 Javascript
微信小程序 navbar实例详解
2017/05/11 Javascript
vue 挂载路由到头部导航的方法
2017/11/13 Javascript
详解Bootstrap 学习(一)入门
2019/04/12 Javascript
layui插件表单验证提交触发提交的例子
2019/09/09 Javascript
微信小程序通过websocket实时语音识别的实现代码
2020/08/19 Javascript
[48:52]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第一局
2016/02/25 DOTA
Python3中多线程编程的队列运作示例
2015/04/16 Python
解决pyqt中ui编译成窗体.py中文乱码的问题
2016/12/23 Python
对python .txt文件读取及数据处理方法总结
2018/04/23 Python
PyTorch学习笔记之回归实战
2018/05/28 Python
python ipset管理 增删白名单的方法
2019/01/14 Python
python实现简单成绩录入系统
2019/09/19 Python
Python的pygame安装教程详解
2020/02/10 Python
马来西亚在线购物:POPLOOK.com
2019/12/09 全球购物
澳大利亚购买太阳镜和眼镜网站:Glamoureyes
2020/09/22 全球购物
护士实习生自我鉴定范文
2013/12/10 职场文书
单位未婚证明范本
2014/01/18 职场文书
爱国演讲稿400字
2014/05/07 职场文书
文体活动总结
2015/02/04 职场文书
大学军训通讯稿
2015/07/18 职场文书