pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之大话题小函数(1)
Oct 10 Python
Python的Tornado框架实现图片上传及图片大小修改功能
Jun 30 Python
Golang与python线程详解及简单实例
Apr 27 Python
Python文件操作基本流程代码实例
Dec 11 Python
python3爬虫之设计签名小程序
Jun 19 Python
Python 的AES加密与解密实现
Jul 09 Python
Python3 使用pillow库生成随机验证码
Aug 26 Python
Python调用shell命令常用方法(4种)
May 11 Python
Python基于gevent实现高并发代码实例
May 15 Python
Python面向对象多态实现原理及代码实例
Sep 16 Python
用python获取txt文件中关键字的数量
Dec 24 Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
PHP函数学习之PHP函数点评
2012/07/05 PHP
在php的yii2框架中整合hbase库的方法
2018/09/20 PHP
JavaScript 题型问答有答案参考
2010/02/17 Javascript
一个简单的js鼠标划过切换效果
2010/06/30 Javascript
JS实现下拉菜单赋值到文本框的方法
2015/08/18 Javascript
JS中call/apply、arguments、undefined/null方法详解
2016/02/15 Javascript
JavaScript基础知识点归纳(推荐)
2016/07/09 Javascript
Javascript获取background属性中url的值
2016/10/17 Javascript
xmlplus组件设计系列之分隔框(DividedBox)(8)
2017/05/02 Javascript
javascript checkbox/radio onchange不能兼容ie8处理办法
2017/06/13 Javascript
JavaScript获取用户所在城市及地理位置
2018/04/21 Javascript
React项目动态设置title标题的方法示例
2018/09/26 Javascript
Element-Ui组件 NavMenu 导航菜单的具体使用
2019/10/24 Javascript
[35:39]完美世界DOTA2联赛PWL S2 FTD.C vs Rebirth 第二场 11.22
2020/11/24 DOTA
python 多进程通信模块的简单实现
2014/02/20 Python
python获取外网ip地址的方法总结
2015/07/02 Python
python实现文本去重且不打乱原本顺序
2016/01/26 Python
Python 多线程实例详解
2017/03/25 Python
对python中词典的values值的修改或新增KEY详解
2019/01/20 Python
通过PYTHON来实现图像分割详解
2019/06/26 Python
python之拟合的实现
2019/07/19 Python
django数据库自动重连的方法实例
2019/07/21 Python
python实现生成Word、docx文件的方法分析
2019/08/30 Python
keras模型可视化,层可视化及kernel可视化实例
2020/01/24 Python
HTML5标签嵌套规则详解【必看】
2016/04/26 HTML / CSS
煤矿班组长的职责
2013/12/25 职场文书
人力资源专员岗位职责
2014/01/30 职场文书
陈欧广告词
2014/03/14 职场文书
相亲活动方案
2014/08/26 职场文书
国庆节促销广告语2014
2014/09/19 职场文书
机关副主任个人四风问题整改措施
2014/09/26 职场文书
机电专业毕业生自我鉴定2014
2014/10/04 职场文书
音乐教师个人工作总结
2015/02/06 职场文书
实习单位鉴定意见
2015/06/04 职场文书
MySQL外键约束(FOREIGN KEY)案例讲解
2021/08/23 MySQL
html5+实现plus.io进行拍照和图片等获取
2022/06/01 HTML / CSS