pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作CouchDB的方法
Oct 08 Python
使用优化器来提升Python程序的执行效率的教程
Apr 02 Python
windows及linux环境下永久修改pip镜像源的方法
Nov 28 Python
Python面向对象编程基础解析(二)
Oct 26 Python
运行django项目指定IP和端口的方法
May 14 Python
python判断输入日期为第几天的实例
Nov 13 Python
python Selenium实现付费音乐批量下载的实现方法
Jan 24 Python
PyQt+socket实现远程操作服务器的方法示例
Aug 22 Python
Tensorflow的常用矩阵生成方式
Jan 04 Python
基于python 等频分箱qcut问题的解决
Mar 03 Python
python用什么编辑器进行项目开发
Jun 17 Python
Python命名空间及作用域原理实例解析
Aug 12 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
PHP初学者最感迷茫的问题小结
2010/03/27 PHP
提示Trying to clone an uncloneable object of class Imagic的解决
2011/10/27 PHP
php通过正则表达式记取数据来读取xml的方法
2015/03/09 PHP
使用xampp搭建运行php虚拟主机的详细步骤
2015/10/21 PHP
[原创]CI(CodeIgniter)简单统计访问人数实现方法
2016/01/19 PHP
PHP模板引擎Smarty中变量的使用方法示例
2016/04/11 PHP
PHP简单获取网站百度搜索和搜狗搜索收录量的方法
2016/08/23 PHP
php防止sql注入的方法详解
2017/02/20 PHP
php基于session锁防止阻塞请求的方法分析
2017/08/07 PHP
PHP实现的防止跨站和xss攻击代码【来自阿里云】
2018/01/29 PHP
几个有趣的Javascript Hack
2010/07/24 Javascript
ASP.NET jQuery 实例6 (实现CheckBoxList成员全选或全取消)
2012/01/13 Javascript
jquery提示效果实例分析
2014/11/25 Javascript
Js和JQuery获取鼠标指针坐标的实现代码分享
2015/05/25 Javascript
javascript实现的多个层切换效果通用函数实例
2015/07/06 Javascript
原生js获取iframe中dom元素--父子页面相互获取对方dom元素的方法
2016/08/05 Javascript
B/S(Web)实时通讯解决方案分享
2017/04/06 Javascript
从零开始学习Node.js系列教程二:文本提交与显示方法
2017/04/13 Javascript
使用vue-cli打包过程中的步骤以及问题的解决
2018/05/08 Javascript
JS中‘hello’与new String(‘hello’)引出的问题详解
2018/08/14 Javascript
[01:00:44]DOTA2上海特级锦标赛主赛事日 - 3 败者组第三轮#1COL VS Alliance第三局
2016/03/04 DOTA
[01:11:21]DOTA2-DPC中国联赛 正赛 VG vs Elephant BO3 第一场 3月6日
2021/03/11 DOTA
Python 实现删除某路径下文件及文件夹的实例讲解
2018/04/24 Python
python实现最小二乘法线性拟合
2019/07/19 Python
在Keras中实现保存和加载权重及模型结构
2020/06/15 Python
一篇文章搞懂python的转义字符及用法
2020/09/03 Python
Allsole美国/加拿大:英国一家专门出售品牌鞋子的网站
2018/10/21 全球购物
荷兰街头时尚之家:Funkie House
2019/03/18 全球购物
奥地利度假券的专家:we-are.travel
2019/04/10 全球购物
会计大学生职业生涯规划书范文
2014/01/13 职场文书
入党积极分子群众意见
2015/06/01 职场文书
为什么阅读对所有年龄段的孩子都很重要?
2019/07/08 职场文书
本地通过nginx配置反向代理的全过程记录
2021/03/31 Servers
React如何创建组件
2021/06/27 Javascript
Java面试题冲刺第十七天--基础篇3
2021/08/07 面试题
不同品牌、不同型号对讲机如何互相通联
2022/02/18 无线电