pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发送邮件示例(支持中文邮件标题)
Feb 16 Python
python连接mysql实例分享
Oct 09 Python
python3.6.3+opencv3.3.0实现动态人脸捕获
May 25 Python
Python全局变量与局部变量区别及用法分析
Sep 03 Python
Django之无名分组和有名分组的实现
Apr 16 Python
python写日志文件操作类与应用示例
Jul 01 Python
让Python脚本暂停执行的几种方法(小结)
Jul 11 Python
python中的global关键字的使用方法
Aug 20 Python
django中使用POST方法获取POST数据
Aug 20 Python
python 绘制正态曲线的示例
Sep 24 Python
多个版本的python共存时使用pip的正确做法
Oct 26 Python
Python tkinter实现日期选择器
Feb 22 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
php strstr查找字符串中是否包含某些字符的查找函数
2010/06/03 PHP
php中判断文件存在是用file_exists还是is_file的整理
2012/09/12 PHP
深入认识JavaScript中的函数
2007/01/22 Javascript
JavaScript之编码规范 推荐
2012/05/23 Javascript
jQuery 处理页面的事件详解
2015/01/20 Javascript
浏览器中url存储的JavaScript实现
2015/07/07 Javascript
JavaScript中的Function函数
2015/08/27 Javascript
JavaScript动态生成二维码图片
2016/04/20 Javascript
javascript简单实现等比例缩小图片的方法
2016/07/27 Javascript
jQuery ztree实现动态树形多选菜单
2016/08/12 Javascript
ES6教程之for循环和Map,Set用法分析
2017/04/10 Javascript
详解Vue.js在页面加载时执行某个方法
2018/11/20 Javascript
解决vue 界面在苹果手机上滑动点击事件等卡顿问题
2018/11/27 Javascript
微信小程序学习笔记之目录结构、基本配置图文详解
2019/03/28 Javascript
jQuery zTree树插件的使用教程
2019/08/16 jQuery
Javascript新手入门之字符串拼接与变量的应用
2020/12/03 Javascript
[02:04]2014DOTA2国际邀请赛 DK一个时代的落幕
2014/07/21 DOTA
[50:59]2018DOTA2亚洲邀请赛 4.7 总决赛 LGD vs Mineski第四场
2018/04/10 DOTA
[43:41]OG vs Newbee 2019国际邀请赛淘汰赛 胜者组 BO3 第一场 8.21.mp4
2020/07/19 DOTA
Python使用Scrapy爬取妹子图
2015/05/28 Python
Python使用pymysql小技巧
2017/06/04 Python
pycharm修改file type方式
2019/11/19 Python
Python爬虫爬取煎蛋网图片代码实例
2019/12/16 Python
python实现猜拳游戏
2020/03/04 Python
Python使用20行代码实现微信聊天机器人
2020/06/05 Python
基于python实现简单C/S模式代码实例
2020/09/14 Python
暇步士官网:Hush Puppies
2016/09/22 全球购物
法国珠宝店:CLEOR
2017/01/29 全球购物
物业管理计划书
2014/01/10 职场文书
大学秋游活动方案
2014/02/11 职场文书
《庐山的云雾》教学反思
2014/04/22 职场文书
促销活动计划书
2014/05/02 职场文书
2014年文秘工作总结
2014/11/25 职场文书
新娘父亲婚礼致辞
2015/07/27 职场文书
浅谈Python类的单继承相关知识
2021/05/12 Python
CSS中calc(100%-100px)不加空格不生效
2023/05/07 HTML / CSS