pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现批量检测HTTP服务的状态
Oct 27 Python
Python+Selenium+PIL+Tesseract自动识别验证码进行一键登录
Sep 20 Python
利用Hyperic调用Python实现进程守护
Jan 02 Python
Python文本处理之按行处理大文件的方法
Apr 09 Python
python 利用栈和队列模拟递归的过程
May 29 Python
python实现验证码识别功能
Jun 07 Python
Flask框架通过Flask_login实现用户登录功能示例
Jul 17 Python
python使用pdfminer解析pdf文件的方法示例
Dec 20 Python
使用python将请求的requests headers参数格式化方法
Jan 02 Python
django框架创建应用操作示例
Sep 26 Python
深入浅析Django MTV模式
Sep 04 Python
python pygame 开发五子棋双人对弈
May 02 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
PHP数组的交集array_intersect(),array_intersect_assoc(),array_inter_key()函数的小问题
2011/05/29 PHP
js和php邮箱地址验证的实现方法
2014/01/09 PHP
两个php日期控制类实例
2014/12/09 PHP
twig里使用js变量的方法
2016/02/05 PHP
在Thinkphp中使用ajax实现无刷新分页的方法
2016/10/25 PHP
CI框架(CodeIgniter)公共模型类定义与用法示例
2017/08/10 PHP
关于JavaScript的gzip静态压缩方法
2007/01/05 Javascript
brook javascript框架介绍
2011/10/10 Javascript
javascript实现博客园页面右下角返回顶部按钮
2015/02/22 Javascript
JS基于Mootools实现的个性菜单效果代码
2015/10/21 Javascript
分享使用AngularJS创建应用的5个框架
2015/12/05 Javascript
浅析JavaScript Array和string的转换(推荐)
2016/05/20 Javascript
微信小程序 页面跳转及数据传递详解
2017/03/14 Javascript
js自定义瀑布流布局插件
2017/05/16 Javascript
集合Bootstrap自定义confirm提示效果
2017/09/19 Javascript
浅谈Angular 中何时取消订阅
2017/11/22 Javascript
vue实现个人信息查看和密码修改功能
2018/05/06 Javascript
Angular4 反向代理Details实践
2018/05/30 Javascript
jQuery控制input只能输入数字和两位小数的方法
2019/05/16 jQuery
vue 路由缓存 路由嵌套 路由守卫 监听物理返回操作
2020/08/06 Javascript
[06:07]DOTA2-DPC中国联赛 正赛 Ehome vs VG 选手采访
2021/03/11 DOTA
Mac中升级Python2.7到Python3.5步骤详解
2017/04/27 Python
python 获取等间隔的数组实例
2019/07/04 Python
python列表每个元素同增同减和列表元素去空格的实例
2019/07/20 Python
python3发送邮件需要经过代理服务器的示例代码
2019/07/25 Python
Python3 tkinter 实现文件读取及保存功能
2019/09/12 Python
如何基于Python批量下载音乐
2019/11/11 Python
Pytorch之保存读取模型实例
2019/12/30 Python
Python新手学习raise用法
2020/06/03 Python
Boden美国官网:英伦原创时装品牌
2017/07/03 全球购物
军训考核自我鉴定
2014/02/13 职场文书
关于祖国的演讲稿
2014/05/04 职场文书
2014购房个人委托书范本
2014/10/12 职场文书
2016年秋季新学期致辞
2015/07/30 职场文书
2019广播稿怎么写
2019/04/17 职场文书
Golang Web 框架Iris安装部署
2022/08/14 Python