pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理图片之PIL模块简单使用方法
May 11 Python
简单了解Python中的几种函数
Nov 03 Python
python3调用R的示例代码
Feb 23 Python
浅谈python日志的配置文件路径问题
Apr 28 Python
Django中使用Celery的方法示例
Nov 29 Python
Python API 自动化实战详解(纯代码)
Jun 11 Python
Python绘制堆叠柱状图的实例
Jul 09 Python
Python 如何提高元组的可读性
Aug 26 Python
Python谱减法语音降噪实例
Dec 18 Python
Django CSRF认证的几种解决方案
Mar 03 Python
解决pyqt5异常退出无提示信息的问题
Apr 08 Python
python实现大文本文件分割成多个小文件
Apr 20 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
如何用php获取程序执行的时间
2013/06/09 PHP
php格式化日期和时间格式化示例分享
2014/02/24 PHP
php封装的smartyBC类完整实例
2016/10/19 PHP
详解thinkphp5+swoole实现异步邮件群发(SMTP方式)
2017/10/13 PHP
PHP实现通过CURL上传文件功能示例
2018/05/30 PHP
phpstudy后门rce批量利用脚本的实现
2019/12/12 PHP
javascript 冒号 使用说明
2009/06/06 Javascript
javascript instanceof 内部机制探析
2010/10/15 Javascript
js获得网页背景色和字体色的方法
2014/03/21 Javascript
js选择并转移导航菜单示例代码
2014/08/19 Javascript
jQuery随机密码生成的方法
2015/03/09 Javascript
特殊日期提示功能的实现方法
2016/06/16 Javascript
js两种拼接字符串的简单方法(必看)
2016/09/02 Javascript
更靠谱的H5横竖屏检测方法(js代码)
2016/09/13 Javascript
JS刷新父窗口的几种方式小结(推荐)
2016/11/09 Javascript
Bootstrap表单使用方法详解
2017/02/17 Javascript
Bootstrap 过渡效果Transition 模态框(Modal)
2017/03/17 Javascript
nodejs和C语言插入mysql数据库乱码问题的解决方法
2017/04/14 NodeJs
JavaScript 中定义函数用 var foo = function () {} 和 function foo()区别介绍
2018/03/01 Javascript
浅谈vue-cli 3.0.x 初体验
2018/04/11 Javascript
vue-cli3配置与跨域处理方法
2019/08/17 Javascript
使用flow来规范javascript的变量类型
2019/09/12 Javascript
javascript实现视频弹幕效果(两个版本)
2019/11/28 Javascript
[07:03]显微镜下的DOTA2第九期——430圣堂刺客杀戮秀
2014/06/20 DOTA
Python基础教程之正则表达式基本语法以及re模块
2016/03/25 Python
Python3.6正式版新特性预览
2016/12/15 Python
基于Python3.6+splinter实现自动抢火车票
2018/09/25 Python
详解django自定义中间件处理
2018/11/21 Python
python模块导入的细节详解
2018/12/10 Python
基于zepto的插件之移动端无缝向上滚动并上下触摸滑动实例代码
2016/12/20 HTML / CSS
ProBikeKit新西兰:自行车套件,跑步和铁人三项装备
2017/04/05 全球购物
可以在一个PHP文件里面include另外一个PHP文件两次吗
2015/05/22 面试题
自荐信格式写作方法有哪些呢
2013/11/20 职场文书
师范生求职自荐信
2014/06/14 职场文书
2014年仓库工作总结
2014/11/20 职场文书
西安导游词
2015/02/12 职场文书