Pytorch 抽取vgg各层并进行定制化处理的方法


Posted in Python onAugust 20, 2019

工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式

def define_vgg(vgg,input_channels,endlayer,use_maxpool=False): 
  vgg_ad = copy.deepcopy(vgg)
  model = nn.Sequential()
  i = 0
  for layer in list(vgg_ad.features):
    if i > endlayer:
      break
    if isinstance(layer, nn.Conv2d) and i is 0:
      name = "conv_" + str(i)
      layer = nn.Conv2d(input_channels,
               layer.out_channels,
               layer.kernel_size,
               stride = layer.stride,
               padding=layer.padding)
      model.add_module(name, layer)
    if isinstance(layer, nn.Conv2d):
      name = "conv_" + str(i)
      model.add_module(name, layer)
 
    if isinstance(layer, nn.ReLU):
      name = "leakyrelu_" + str(i)
      layer = nn.LeakyReLU(inplace=True) 
      model.add_module(name, layer)
 
    if isinstance(layer, nn.MaxPool2d):
      name = "pool_" + str(i)
      if use_maxpool:
        model.add_module(name, layer)
      else:
        avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding)
        model.add_module(name, avgpool)
    i += 1
  return model

函数输入项中的vgg 是直接使用的import torchvision.models.vgg16 传入的是vgg16 非预训练版本。end_layer 是需要提取的层数,这里使用了vgg.features 是指仅仅在vgg.features 上进行层的提取;也可以根据定制在classifier上进行提取。

下面是我的一个提取前7层的示例,可以使用pyCharm evaluate 上面函数返回的model,可以看到这个示例的情况,这里我的定制条件是输入通道为2 ,需要提取前7层,并且将ReLu更换为LeakyRelu。

Sequential(
 (conv_0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_1): LeakyReLU(negative_slope=0.01, inplace)
 (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_3): LeakyReLU(negative_slope=0.01, inplace)
 (pool_4): AvgPool2d(kernel_size=2, stride=2, padding=0)
 (conv_5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_6): LeakyReLU(negative_slope=0.01, inplace)
 (conv_7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

以上这篇Pytorch 抽取vgg各层并进行定制化处理的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现ip查询示例
Mar 26 Python
Pycharm远程调试openstack的方法
Nov 21 Python
Python实现矩阵加法和乘法的方法分析
Dec 19 Python
轻松实现TensorFlow微信跳一跳的AI
Jan 05 Python
python中的内置函数max()和min()及mas()函数的高级用法
Mar 29 Python
如何使用Python的Requests包实现模拟登陆
Apr 27 Python
基于python的ini配置文件操作工具类
Apr 24 Python
利用selenium爬虫抓取数据的基础教程
Jun 10 Python
python针对mysql数据库的连接、查询、更新、删除操作示例
Sep 11 Python
在Python中使用MySQL--PyMySQL的基本使用方法
Nov 19 Python
Pycharm连接gitlab实现过程图解
Sep 01 Python
Python按顺序遍历并读取文件夹中文件
Apr 29 Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 #Python
pytorch在fintune时将sequential中的层输出方法,以vgg为例
Aug 20 #Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
You might like
一个PHP日历程序
2006/12/06 PHP
ThinkPHP实现登录退出功能
2017/06/29 PHP
jquery中的sortable排序之后的保存状态的解决方法
2010/01/28 Javascript
通过Jquery遍历Json的两种数据结构的实现代码
2011/01/19 Javascript
jQuery之end()和pushStack()使用介绍
2012/02/07 Javascript
node.js中的fs.fstatSync方法使用说明
2014/12/15 Javascript
鼠标事件的screenY,pageY,clientY,layerY,offsetY属性详解
2015/03/12 Javascript
JavaScript中的getTime()方法使用详解
2015/06/10 Javascript
深入解析Backbone.js框架的依赖库Underscore.js的作用
2016/05/07 Javascript
JS基于构造函数实现的菜单滑动显隐效果【测试可用】
2016/06/21 Javascript
Angularjs自定义指令Directive详解
2017/05/27 Javascript
详解Angular2响应式表单
2017/06/14 Javascript
JS实现延迟隐藏功能的方法(类似QQ头像鼠标放上展示信息)
2017/12/28 Javascript
vue 中引用gojs绘制E-R图的方法示例
2018/08/24 Javascript
小程序视频或音频自定义可拖拽进度条的示例代码
2018/09/30 Javascript
vue前后分离调起微信支付
2019/07/29 Javascript
vue element 关闭当前tab 跳转到上一路由操作
2020/07/22 Javascript
vue-cli3项目打包后自动化部署到服务器的方法
2020/09/16 Javascript
[02:40]DOTA2英雄基础教程 先知
2013/11/29 DOTA
[00:32]2018DOTA2亚洲邀请赛出场——VP
2018/04/04 DOTA
Python使用scrapy采集数据时为每个请求随机分配user-agent的方法
2015/04/08 Python
简单的python后台管理程序
2017/04/13 Python
python实现kNN算法
2017/12/20 Python
PyQt5每天必学之创建窗口居中效果
2018/04/19 Python
python一键去抖音视频水印工具
2018/09/14 Python
pandas去重复行并分类汇总的实现方法
2019/01/29 Python
Python中typing模块与类型注解的使用方法
2019/08/05 Python
Python 多线程其他属性以及继承Thread类详解
2019/08/28 Python
意大利比基尼品牌:MISS BIKINI
2019/11/02 全球购物
大学军训感言200字
2014/02/26 职场文书
保护环境建议书
2014/03/12 职场文书
2015新生加入学生会自荐书
2015/03/24 职场文书
检察院起诉书
2015/05/20 职场文书
迎国庆主题班会
2015/08/17 职场文书
Java 语言中Object 类和System 类详解
2021/07/07 Java/Android
Python制作动态字符画的源码
2021/08/04 Python