pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python检测主机存活端口及检查存活主机
Oct 12 Python
轻松掌握python设计模式之访问者模式
Nov 18 Python
Python变量和数据类型详解
Feb 15 Python
定制FileField中的上传文件名称实例
Aug 23 Python
Python序列化基础知识(json/pickle)
Oct 19 Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 Python
解决nohup重定向python输出到文件不成功的问题
May 11 Python
python 处理微信对账单数据的实例代码
Jul 19 Python
程序员的七夕用30行代码让Python化身表白神器
Aug 07 Python
简单了解python中的与或非运算
Sep 18 Python
在keras 中获取张量 tensor 的维度大小实例
Jun 10 Python
Python接口自动化系列之unittest结合ddt的使用教程详解
Feb 23 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
php中常用字符串处理代码片段整理
2011/11/07 PHP
适用于初学者的简易PHP文件上传类
2015/10/29 PHP
php模拟post上传图片实现代码
2016/06/24 PHP
PHP简单实现二维数组赋值与遍历功能示例
2017/10/19 PHP
JavaScript中的对象化编程
2008/01/16 Javascript
javascript function、指针及内置对象
2009/02/19 Javascript
js的写法基础分析
2011/01/17 Javascript
推荐17个优美新鲜的jQuery的工具提示插件
2012/09/14 Javascript
JS正则表达式获取分组内容的方法详解
2013/11/15 Javascript
JavaScript italics方法入门实例(把字符串显示为斜体)
2014/10/17 Javascript
JavaScript在Android的WebView中parseInt函数转换不正确问题解决方法
2015/04/25 Javascript
Jquery attr()方法 属性赋值和属性获取详解
2016/04/15 Javascript
深入浅析knockout源码分析之订阅
2016/07/12 Javascript
关于jQuery.ajax()的jsonp碰上post详解
2017/07/02 jQuery
vue组件watch属性实例讲解
2017/11/07 Javascript
详解webpack与SPA实践之开发环境搭建
2017/12/18 Javascript
vue router总结 $router和$route及router与 router与route区别
2019/07/05 Javascript
vue下使用nginx刷新页面404的问题解决
2019/08/02 Javascript
如何基于javascript实现贪吃蛇游戏
2020/02/09 Javascript
JS实现公告上线滚动效果
2021/01/10 Javascript
[53:44]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Magma BO3 第一场 1月31日
2021/03/11 DOTA
python在不同层级目录import模块的方法
2016/01/31 Python
Python字符串逆序输出的实例讲解
2019/02/16 Python
python3 反射的四种基本方法解析
2019/08/26 Python
python rsa实现数据加密和解密、签名加密和验签功能
2019/09/18 Python
python读取yaml文件后修改写入本地实例
2020/04/27 Python
Python3自动生成MySQL数据字典的markdown文本的实现
2020/05/07 Python
CSS3中使用RGBA设置透明度的示例
2015/08/04 HTML / CSS
儿科护士自我鉴定
2013/10/14 职场文书
中学生运动会入场词
2014/02/12 职场文书
乡镇办公室工作决心书
2014/03/11 职场文书
幼儿园父亲节活动方案
2014/03/11 职场文书
2015年英语教师工作总结
2015/05/20 职场文书
2015年社区工会工作总结
2015/05/26 职场文书
宝葫芦的秘密观后感
2015/06/11 职场文书
Ajax实现三级联动效果
2021/10/05 Javascript