Pytorch 抽取vgg各层并进行定制化处理的方法


Posted in Python onAugust 20, 2019

工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式

def define_vgg(vgg,input_channels,endlayer,use_maxpool=False): 
  vgg_ad = copy.deepcopy(vgg)
  model = nn.Sequential()
  i = 0
  for layer in list(vgg_ad.features):
    if i > endlayer:
      break
    if isinstance(layer, nn.Conv2d) and i is 0:
      name = "conv_" + str(i)
      layer = nn.Conv2d(input_channels,
               layer.out_channels,
               layer.kernel_size,
               stride = layer.stride,
               padding=layer.padding)
      model.add_module(name, layer)
    if isinstance(layer, nn.Conv2d):
      name = "conv_" + str(i)
      model.add_module(name, layer)
 
    if isinstance(layer, nn.ReLU):
      name = "leakyrelu_" + str(i)
      layer = nn.LeakyReLU(inplace=True) 
      model.add_module(name, layer)
 
    if isinstance(layer, nn.MaxPool2d):
      name = "pool_" + str(i)
      if use_maxpool:
        model.add_module(name, layer)
      else:
        avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding)
        model.add_module(name, avgpool)
    i += 1
  return model

函数输入项中的vgg 是直接使用的import torchvision.models.vgg16 传入的是vgg16 非预训练版本。end_layer 是需要提取的层数,这里使用了vgg.features 是指仅仅在vgg.features 上进行层的提取;也可以根据定制在classifier上进行提取。

下面是我的一个提取前7层的示例,可以使用pyCharm evaluate 上面函数返回的model,可以看到这个示例的情况,这里我的定制条件是输入通道为2 ,需要提取前7层,并且将ReLu更换为LeakyRelu。

Sequential(
 (conv_0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_1): LeakyReLU(negative_slope=0.01, inplace)
 (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_3): LeakyReLU(negative_slope=0.01, inplace)
 (pool_4): AvgPool2d(kernel_size=2, stride=2, padding=0)
 (conv_5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_6): LeakyReLU(negative_slope=0.01, inplace)
 (conv_7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

以上这篇Pytorch 抽取vgg各层并进行定制化处理的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用items()方法返回字典元素对的教程
May 21 Python
Django框架中render_to_response()函数的使用方法
Jul 16 Python
python开发之IDEL(Python GUI)的使用方法图文详解
Nov 12 Python
Python实现简易端口扫描器代码实例
Mar 15 Python
Python 含参构造函数实例详解
May 25 Python
Python3中关于cookie的创建与保存
Oct 21 Python
pandas DataFrame索引行列的实现
Jun 04 Python
python3 tkinter实现添加图片和文本
Nov 26 Python
python函数map()和partial()的知识点总结
May 26 Python
Python偏函数Partial function使用方法实例详解
Jun 17 Python
django教程如何自学
Jul 31 Python
用Python生成会跳舞的美女
Jan 18 Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 #Python
pytorch在fintune时将sequential中的层输出方法,以vgg为例
Aug 20 #Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
You might like
php并发对MYSQL造成压力的解决方法
2013/02/21 PHP
Yii分页用法实例详解
2014/12/04 PHP
php mysqli查询语句返回值类型实例分析
2016/06/29 PHP
swoole锁的机制代码实例讲解
2021/03/04 PHP
jQuery 表单验证扩展代码(二)
2010/10/20 Javascript
Ext对基本类型的扩展 ext,extjs,format
2010/12/25 Javascript
ExtJS4中的requires使用方法示例介绍
2013/12/03 Javascript
原生Ajax 和jQuery Ajax的区别示例分析
2014/12/17 Javascript
avalon js实现仿google plus图片多张拖动排序附源码下载
2015/09/24 Javascript
jQuery mobile转换url地址及获取url中目录部分的方法
2015/12/04 Javascript
gameboy网页闯关游戏(riddle webgame)--仿微信聊天的前端页面设计和难点
2016/02/21 Javascript
原生JS实现旋转木马式图片轮播插件
2016/04/25 Javascript
js判断所有表单项不为空则提交表单的实现方法
2016/09/09 Javascript
移动端js触摸事件详解
2016/09/18 Javascript
简单理解vue中track-by属性
2016/10/26 Javascript
CodeMirror js代码加亮使用总结
2017/03/25 Javascript
VUE2.0中Jsonp的使用方法
2018/05/22 Javascript
当vue路由变化时,改变导航栏的样式方法
2018/08/22 Javascript
JavaScript使用indexOf()实现数组去重的方法分析
2018/09/04 Javascript
浅谈vue项目打包优化策略
2018/09/29 Javascript
JS使用正则表达式提交页面验证的代码
2019/10/16 Javascript
Node.js API详解之 net模块实例分析
2020/05/18 Javascript
VUE实时监听元素距离顶部高度的操作
2020/07/29 Javascript
vue中渲染对象中属性时显示未定义的解决
2020/07/31 Javascript
Python中的XML库4Suite Server的介绍
2015/04/14 Python
深入了解Python数据类型之列表
2016/06/24 Python
python-opencv 将连续图片写成视频格式的方法
2019/01/08 Python
django的模型类管理器——数据库操作的封装详解
2020/04/01 Python
德购商城:德国进口直邮商城
2017/06/13 全球购物
Nice Kicks网上商店:ShopNiceKicks.com
2018/12/25 全球购物
Glamest意大利:女性在线奢侈品零售店
2019/04/28 全球购物
自荐信格式范文
2013/10/07 职场文书
高中毕业的自我鉴定
2013/12/09 职场文书
班组长安全职责
2014/01/05 职场文书
活动志愿者自荐信
2014/01/27 职场文书
2015年前台文员工作总结
2015/05/18 职场文书