pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
win7安装python生成随机数代码分享
Dec 27 Python
python使用urllib模块开发的多线程豆瓣小站mp3下载器
Jan 16 Python
如何利用python查找电脑文件
Apr 27 Python
使用Python操作FTP实现上传和下载的方法
Apr 01 Python
python文件写入write()的操作
May 14 Python
python中时间、日期、时间戳的转换的实现方法
Jul 06 Python
django多种支付、并发订单处理实例代码
Dec 13 Python
Python程序控制语句用法实例分析
Jan 14 Python
python实现爱奇艺登陆密码RSA加密的方法示例详解
May 27 Python
手把手教你从PyCharm安装到激活(最新激活码),亲测有效可激活至2089年
Nov 25 Python
Python 爬虫批量爬取网页图片保存到本地的实现代码
Dec 24 Python
pytorch通过训练结果的复现设置随机种子
Jun 01 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
BBS(php & mysql)完整版(四)
2006/10/09 PHP
一个图形显示IP的PHP程序代码
2007/10/19 PHP
php下HTTP Response中的Chunked编码实现方法
2008/11/19 PHP
PHP防止跨域提交表单
2013/11/01 PHP
PHP文件操作之获取目录下文件与计算相对路径的方法
2016/01/08 PHP
yii 框架实现按天,月,年,自定义时间段统计数据的方法分析
2020/04/04 PHP
javascript 原型模式实现OOP的再研究
2009/04/09 Javascript
用JQuery模仿淘宝的图片放大镜显示效果
2011/09/15 Javascript
图片img的src不变让浏览器重新加载实现方法
2013/03/29 Javascript
json中换行符的处理方法示例介绍
2014/06/10 Javascript
页面内容排序插件jSort使用方法
2015/10/10 Javascript
javascript jquery对form元素的常见操作详解
2016/06/12 Javascript
详解react如何在组件中获取路由参数
2017/06/15 Javascript
JavaScript标准对象_动力节点Java学院整理
2017/06/27 Javascript
Vue2.0基于vue-cli+webpack Vuex的用法(实例讲解)
2017/09/15 Javascript
小程序数据通信方法大全(推荐)
2019/04/15 Javascript
微信小程序收货地址API兼容低版本解决方法
2019/05/18 Javascript
vue-i18n结合Element-ui的配置方法
2019/05/20 Javascript
JavaScript HTML DOM元素 节点操作汇总
2019/07/29 Javascript
JS合并两个数组的3种方法详解
2019/10/24 Javascript
[48:12]Secret vs Optic Supermajor 胜者组 BO3 第三场 6.4
2018/06/05 DOTA
python抓取京东价格分析京东商品价格走势
2014/01/09 Python
解析Python中的异常处理
2015/04/28 Python
Python的Asyncore异步Socket模块及实现端口转发的例子
2016/06/14 Python
Python排序搜索基本算法之选择排序实例分析
2017/12/09 Python
Python安装pycurl失败的解决方法
2018/10/15 Python
Python 窗体(tkinter)按钮 位置实例
2019/06/13 Python
python绘制雪景图
2019/12/16 Python
Python sublime安装及配置过程详解
2020/06/29 Python
IBatis持久层技术
2016/07/18 面试题
演讲稿开场白
2014/01/13 职场文书
《九色鹿》教学反思
2014/02/27 职场文书
环保宣传标语
2014/06/12 职场文书
校园环保广播稿(3篇)
2014/09/15 职场文书
2014年节能工作总结
2014/12/18 职场文书
bat批处理之字符串操作的实现
2022/03/16 Python