pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发布模块的步骤分享
Feb 21 Python
Python StringIO模块实现在内存缓冲区中读写数据
Apr 08 Python
利用python发送和接收邮件
Sep 27 Python
Python实现PS图像调整之对比度调整功能示例
Jan 26 Python
python linecache 处理固定格式文本数据的方法
Jan 08 Python
Python 共享变量加锁、释放详解
Aug 28 Python
python文件绝对路径写法介绍(windows)
Dec 25 Python
Python逐行读取文件内容的方法总结
Feb 14 Python
Python坐标轴操作及设置代码实例
Jun 04 Python
Django:使用filter的pk进行多值查询操作
Jul 15 Python
Python更改pip镜像源的方法示例
Dec 01 Python
Python Django搭建文件下载服务器的实现
May 10 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
php,不用COM,生成excel文件
2006/10/09 PHP
PHP similar_text 字符串的相似性比较函数
2010/05/26 PHP
PHP的openssl加密扩展使用小结(推荐)
2016/07/18 PHP
详解PHP处理密码的几种方式
2016/11/30 PHP
safari下载文件自动加了html后缀问题
2018/11/09 PHP
jQuery 事件的命名空间简单了解
2013/11/22 Javascript
鼠标左键单击冲突的问题解决方法(防止冒泡)
2014/05/14 Javascript
node.js中的forEach()是同步还是异步呢
2015/01/29 Javascript
jquery右下角自动弹出可关闭的广告层
2015/05/08 Javascript
pace.js页面加载进度条插件
2015/09/29 Javascript
AngularJS基础 ng-mousemove 指令简单示例
2016/08/02 Javascript
vue.js通过自定义指令实现数据拉取更新的实现方法
2016/10/18 Javascript
详解angularjs的数组传参方式的简单实现
2017/07/28 Javascript
javascript 作用于作用域链的详解
2017/09/27 Javascript
js 只比较时间大小的实例
2017/10/26 Javascript
Vue.js用法详解
2017/11/13 Javascript
vue实现分页组件
2020/06/16 Javascript
vue 引用自定义ttf、otf、在线字体的方法
2019/05/09 Javascript
Vue拖拽组件列表实现动态页面配置功能
2019/06/17 Javascript
Vue 中 filter 与 computed 的区别与用法解析
2019/11/21 Javascript
Node.js学习之内置模块fs用法示例
2020/01/22 Javascript
Python中实现结构相似的函数调用方法
2015/03/10 Python
python3操作微信itchat实现发送图片
2018/02/24 Python
详解Python数据可视化编程 - 词云生成并保存(jieba+WordCloud)
2019/03/26 Python
利用PyQt中的QThread类实现多线程
2020/02/18 Python
python 将视频 通过视频帧转换成时间实例
2020/04/23 Python
Python实现动态循环输出文字功能
2020/05/07 Python
Python中内建模块collections如何使用
2020/05/27 Python
keras和tensorflow使用fit_generator 批次训练操作
2020/07/03 Python
西部世纪.net笔试题面试题
2014/04/03 面试题
竞选学生会演讲稿
2014/04/25 职场文书
2014年团支书工作总结
2014/11/14 职场文书
六一儿童节新闻稿
2015/07/17 职场文书
聊聊mysql都有哪几种分区方式
2022/04/13 MySQL
Nginx 安装SSL证书完成HTTPS部署
2022/04/28 Servers
从原生JavaScript到React深入理解
2022/07/23 Javascript