pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单介绍使用Python解析并修改XML文档的方法
Oct 15 Python
Python3使用requests登录人人影视网站的方法
May 11 Python
python实现比较文件内容异同
Jun 22 Python
Python基础知识点 初识Python.md
May 14 Python
关于 Python opencv 使用中的 ValueError: too many values to unpack
Jun 28 Python
Python&&GDAL实现NDVI的计算方式
Jan 09 Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 Python
pandas和spark dataframe互相转换实例详解
Feb 18 Python
Python数据可视化实现多种图例代码详解
Jul 14 Python
详解Python 函数参数的拆解
Sep 02 Python
如何使用python自带IDLE的几种方法
Oct 10 Python
python3中for循环踩过的坑记录
Dec 14 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
怎样辨别一杯好咖啡
2021/03/03 新手入门
跟我学小偷程序之成功偷取首页(第三天)
2006/10/09 PHP
php递归列出所有文件和目录的代码
2008/09/10 PHP
PHP函数spl_autoload_register()用法和__autoload()介绍
2012/02/04 PHP
PHP中运用jQuery的Ajax跨域调用实现代码
2012/02/21 PHP
php获取qq用户昵称和在线状态(实例分析)
2013/10/27 PHP
destoon实现调用自增数字从1开始的方法
2014/08/21 PHP
php结合正则获取字符串中数字
2015/06/19 PHP
PHP整合七牛实现上传文件
2015/07/03 PHP
PHP explode()函数用法讲解
2019/02/15 PHP
jquery 防止表单重复提交代码
2010/01/21 Javascript
JSON+HTML实现国家省市联动选择效果
2014/05/18 Javascript
js的[defer]和[async]属性
2014/11/24 Javascript
JavaScript中数据结构与算法(四):串(BF)
2015/06/19 Javascript
基于jquery实现鼠标滚轮驱动的图片切换效果
2015/10/26 Javascript
JS+CSS实现DIV层的展开、收缩效果
2016/01/28 Javascript
JavaScript对象封装的简单实现方法(3种方法)
2017/01/03 Javascript
Javascript实现页面滚动时导航智能定位
2017/05/06 Javascript
轻量级富文本编辑器wangEditor结合vue使用方法示例
2018/10/10 Javascript
vue实现的封装全局filter并统一管理操作示例
2020/02/02 Javascript
JavaScript实现手机号码 3-4-4格式并控制新增和删除时光标的位置
2020/06/02 Javascript
Node.js fs模块原理及常见用途
2020/10/22 Javascript
使用简单工厂模式来进行Python的设计模式编程
2016/03/01 Python
Python算法之图的遍历
2017/11/16 Python
python获取命令行输入参数列表的实例代码
2018/06/23 Python
Python实现的读取/更改/写入xml文件操作示例
2018/08/30 Python
pycharm打开命令行或Terminal的方法
2019/01/16 Python
Pytorch中的VGG实现修改最后一层FC
2020/01/15 Python
Python如何批量获取文件夹的大小并保存
2020/03/31 Python
使用Python实现批量ping操作方法
2020/05/06 Python
使用 django orm 写 exists 条件过滤实例
2020/05/20 Python
如何利用Python动态模拟太阳系运转
2020/09/04 Python
美国运动鞋类和服装零售连锁店:Shoe Palace
2019/08/13 全球购物
研究生导师评语
2014/12/31 职场文书
2016关于军训的心得体会
2016/01/11 职场文书
新手入门Jvm-- JVM对象创建与内存分配机制
2021/06/18 Java/Android