pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python正则表达式匹配ip地址实例
Oct 09 Python
Python中字符串对齐方法介绍
May 21 Python
python中PIL安装简单教程
Apr 21 Python
windows10系统中安装python3.x+scrapy教程
Nov 08 Python
python与php实现分割文件代码
Mar 06 Python
Python实现图片尺寸缩放脚本
Mar 10 Python
python dataframe常见操作方法:实现取行、列、切片、统计特征值
Jun 09 Python
ORM Django 终端打印 SQL 语句实现解析
Aug 09 Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 Python
Python cookie的保存与读取、SSL讲解
Feb 17 Python
pandas数据拼接的实现示例
Apr 16 Python
Python面向对象多态实现原理及代码实例
Sep 16 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
PHP技术开发技巧分享
2010/03/23 PHP
PHP+MySQL统计该库中每个表的记录数并按递减顺序排列的方法
2016/02/15 PHP
Zend Framework框架教程之Zend_Db_Table_Rowset用法实例分析
2016/03/21 PHP
PHP在linux上执行外部命令的方法
2017/02/06 PHP
解决安装WampServer时提示缺少msvcr110.dll文件的问题
2017/07/09 PHP
Aster vs KG BO3 第二场2.19
2021/03/10 DOTA
翻译整理的jQuery使用查询手册
2007/03/07 Javascript
location.href语句与火狐不兼容的问题
2010/07/04 Javascript
MooTools 页面滚动浮动层智能定位实现代码
2011/08/23 Javascript
js去除重复字符串两种实现方法
2013/01/09 Javascript
JQuery的read函数与js的onload不同方式实现
2013/03/18 Javascript
js实现div闪烁原理及实现代码
2014/06/24 Javascript
谈谈Jquery中的children find 的区别有哪些
2015/10/19 Javascript
Vue.js实现一个自定义分页组件vue-paginaiton
2016/09/05 Javascript
关于vue.js弹窗组件的知识点总结
2016/09/11 Javascript
浅谈js函数中的实例对象、类对象、局部变量(局部函数)
2016/11/20 Javascript
探讨跨域请求资源的几种方式(总结)
2016/12/02 Javascript
javascript循环链表之约瑟夫环的实现方法
2017/01/16 Javascript
JS实现加载时锁定HTML页面元素的方法
2017/06/24 Javascript
JavaScript实现微信号随机切换代码
2018/03/09 Javascript
使用Node搭建reactSSR服务端渲染架构
2018/08/30 Javascript
vue-cli的build的文件夹下没有dev-server.js文件配置mock数据的方法
2019/04/17 Javascript
javascript使用substring实现的展开与收缩文字功能示例
2019/06/17 Javascript
[03:51]吞吞映像 每周精彩击杀top10第二弹
2014/06/25 DOTA
[38:39]完美世界DOTA2联赛循环赛 IO vs GXR BO2第二场 11.04
2020/11/05 DOTA
详细讲解Python中的文件I/O操作
2015/05/24 Python
在Python中使用turtle绘制多个同心圆示例
2019/11/23 Python
python访问hdfs的操作
2020/06/06 Python
python爬虫如何解决图片验证码
2021/02/14 Python
韩国著名的在线综合购物网站:Akmall
2016/08/07 全球购物
家庭户外服装:Hawkshead
2017/11/02 全球购物
授权委托书怎么写
2014/04/03 职场文书
诚实守信道德模范事迹材料
2014/08/15 职场文书
2015年组织部工作总结
2015/04/03 职场文书
房屋产权证明书
2015/06/19 职场文书
纪检监察立案决定书
2015/06/24 职场文书