pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中pygame模块用法实例
Oct 09 Python
python 多线程实现检测服务器在线情况
Nov 25 Python
django rest framework 数据的查找、过滤、排序的示例
Jun 25 Python
djang常用查询SQL语句的使用代码
Feb 15 Python
python实现抽奖小程序
Apr 15 Python
pymysql模块的使用(增删改查)详解
Sep 09 Python
python栈的基本定义与使用方法示例【初始化、赋值、入栈、出栈等】
Oct 24 Python
NumPy中的维度Axis详解
Nov 26 Python
在django admin详情表单显示中添加自定义控件的实现
Mar 11 Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 Python
python实现网页录音效果
Oct 26 Python
浅析Python模块之间的相互引用问题
Feb 26 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
PHP概述.
2006/10/09 PHP
php学习笔记(三)操作符与控制结构
2011/08/06 PHP
使用Composer安装Yii框架的方法
2016/03/15 PHP
php中final关键字用法分析
2016/12/07 PHP
laravel中的一些简单实用功能
2018/11/03 PHP
传递参数的标准方法(jQuery.ajax)
2008/11/19 Javascript
js实现兼容IE6与IE7的DIV高度
2010/05/13 Javascript
使用PHP+JQuery+Ajax分页的实现
2013/04/23 Javascript
JavaScript中对象介绍
2014/12/31 Javascript
JavaScript实现动态添加,删除行的方法实例详解
2015/07/02 Javascript
实例代码讲解jquery easyui动态tab页
2015/11/17 Javascript
Javascript生成全局唯一标识符(GUID,UUID)的方法
2016/02/27 Javascript
Select下拉框模糊查询功能实现代码
2016/07/22 Javascript
基于JS如何实现给字符加千分符(65,541,694,158)
2016/08/03 Javascript
Vue数据驱动模拟实现5
2017/01/13 Javascript
vue下拉菜单组件(含搜索)的实现代码
2018/11/25 Javascript
小程序获取当前位置加搜索附近热门小区及商区的方法
2019/04/08 Javascript
Angular 多级路由实现登录页面跳转(小白教程)
2019/11/19 Javascript
在Echarts图中给坐标轴加一个标识线markLine
2020/07/20 Javascript
vue $mount 和 el的区别说明
2020/09/11 Javascript
js实现验证码干扰(动态)
2021/02/23 Javascript
Python3里的super()和__class__使用介绍
2015/04/23 Python
python密码错误三次锁定(实例讲解)
2017/11/14 Python
python初学之用户登录的实现过程(实例讲解)
2017/12/23 Python
python tkinter库实现气泡屏保和锁屏
2019/07/29 Python
Python numpy数组转置与轴变换
2019/11/15 Python
详解使用Python写一个向数据库填充数据的小工具(推荐)
2020/09/11 Python
美国在线印刷公司:PsPrint
2017/10/12 全球购物
大学生旅游业创业计划书
2014/01/29 职场文书
三月学雷锋月活动总结
2014/04/28 职场文书
2014年会计人员工作总结
2014/12/10 职场文书
2015年班长个人工作总结
2015/04/03 职场文书
婚宴新郎致辞
2015/07/28 职场文书
《乘法分配律》教学反思
2016/02/24 职场文书
学习型家庭事迹材料(2016精选版)
2016/02/29 职场文书
Vue Element UI自定义描述列表组件
2021/05/18 Vue.js