pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬取读者并制作成PDF
Mar 10 Python
编写Python CGI脚本的教程
Jun 29 Python
Python3中的真除和Floor除法用法分析
Mar 16 Python
用yum安装MySQLdb模块的步骤方法
Dec 15 Python
Python基于numpy灵活定义神经网络结构的方法
Aug 19 Python
Python遍历某目录下的所有文件夹与文件路径
Mar 15 Python
Python实现加载及解析properties配置文件的方法
Mar 29 Python
python conda操作方法
Sep 11 Python
Python 支持向量机分类器的实现
Jan 15 Python
Pandas —— resample()重采样和asfreq()频度转换方式
Feb 26 Python
利用python对mysql表做全局模糊搜索并分页实例
Jul 12 Python
python百行代码实现汉服圈图片爬取
Nov 23 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
PHP转换文件夹下所有文件编码的实现代码
2013/06/06 PHP
基于PHP实现数据分页显示功能
2016/05/26 PHP
微信支付扫码支付php版
2016/07/22 PHP
ThinkPHP5框架实现简单的批量查询功能示例
2018/06/07 PHP
javascript之解决IE下不渲染的bug
2007/06/29 Javascript
js+ajax实现获取文件大小的方法
2015/12/08 Javascript
Javascript removeChild()删除节点及删除子节点的方法
2015/12/27 Javascript
javascript类型系统 Window对象学习笔记
2016/01/07 Javascript
Javascript动画效果(2)
2016/10/11 Javascript
Vue.js开发环境快速搭建教程
2017/03/17 Javascript
vue.js在标签属性中插入变量参数的方法
2018/03/06 Javascript
vue2 拖动排序 vuedraggable组件的实现
2019/08/08 Javascript
webpack的pitching loader详解
2019/09/23 Javascript
vue-router的hooks用法详解
2020/06/08 Javascript
谈谈JavaScript中的函数
2020/09/08 Javascript
vue 实现element-ui中的加载中状态
2020/11/11 Javascript
[02:27]2018DOTA2亚洲邀请赛趣味视频之钓鱼大赛 谁是垂钓冠军?
2018/04/05 DOTA
Python中staticmethod和classmethod的作用与区别
2018/10/11 Python
python scipy求解非线性方程的方法(fsolve/root)
2018/11/12 Python
Django2.1.3 中间件使用详解
2018/11/26 Python
基于python的itchat库实现微信聊天机器人(推荐)
2019/10/29 Python
Pythonic版二分查找实现过程原理解析
2020/08/11 Python
基于Django快速集成Echarts代码示例
2020/12/01 Python
python自动化发送邮件实例讲解
2021/01/04 Python
意大利会呼吸的鞋:Geox健乐士
2017/02/12 全球购物
SheIn俄罗斯:时尚女装网上商店
2017/02/28 全球购物
New Balance美国官网:运动鞋和健身服装
2017/04/11 全球购物
牦牛毛户外探险服装:Kora
2019/02/08 全球购物
Interhome丹麦:在线预订度假屋和公寓
2019/07/18 全球购物
Ancheer官方户外和运动商店:销售电动自行车
2019/08/07 全球购物
三年级语文教学反思
2014/02/01 职场文书
优秀家长自荐材料
2014/08/26 职场文书
2014年社区计生工作总结
2014/11/18 职场文书
先进基层党组织事迹材料
2014/12/25 职场文书
2015年大学生实习评语
2015/03/25 职场文书
民主生活会意见
2015/06/05 职场文书