使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多进程和多线程究竟谁更快(详解)
May 29 Python
机器学习10大经典算法详解
Dec 07 Python
浅谈Python的list中的选取范围
Nov 12 Python
Python实现的调用C语言函数功能简单实例
Mar 13 Python
详解Python数据分析--Pandas知识点
Mar 23 Python
Python中调用其他程序的方式详解
Aug 06 Python
用python实现英文字母和相应序数转换的方法
Sep 18 Python
python实现while循环打印星星的四种形状
Nov 23 Python
keras获得某一层或者某层权重的输出实例
Jan 24 Python
python如何写出表白程序
Jun 01 Python
python读取图像矩阵文件并转换为向量实例
Jun 18 Python
python中count函数知识点浅析
Dec 17 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
PHP中限制IP段访问、禁止IP提交表单的代码
2011/04/23 PHP
Nginx下ThinkPHP5的配置方法详解
2017/08/01 PHP
PHP利用百度ai实现文本和图片审核
2019/05/08 PHP
jquery jqPlot API 中文使用教程(非常强大的图表工具)
2011/08/15 Javascript
javascript中将Object转换为String函数代码 (json str)
2012/04/29 Javascript
点击A元素触发B元素的事件在IE8下会识别成A元素
2014/09/04 Javascript
js实现点击链接后延迟3秒再跳转的方法
2015/06/05 Javascript
JavaScript统计网站访问次数的实现代码
2015/11/18 Javascript
JavaScript html5 canvas绘制时钟效果
2016/03/01 Javascript
JavaScript是如何实现继承的(六种方式)
2016/03/31 Javascript
jQuery实现表格与ckeckbox的全选与单选功能
2016/11/24 Javascript
headjs实现网站并行加载但顺序执行JS
2016/11/29 Javascript
浅谈vue-lazyload实现的详细过程
2017/08/22 Javascript
浅谈Vue.js中ref ($refs)用法举例总结
2017/12/19 Javascript
详解Vue webapp项目通过HBulider打包原生APP
2018/06/29 Javascript
vue.js编译时给生成的文件增加版本号
2018/09/17 Javascript
Python里disconnect UDP套接字的方法
2015/04/23 Python
详解利用Python scipy.signal.filtfilt() 实现信号滤波
2019/06/05 Python
12个步骤教你理解Python装饰器
2019/07/01 Python
Python爬虫库BeautifulSoup的介绍与简单使用实例
2020/01/25 Python
Jupyter Notebook远程登录及密码设置操作
2020/04/10 Python
Python库skimage绘制二值图像代码实例
2020/04/10 Python
Python -m参数原理及使用方法解析
2020/08/21 Python
python 合并多个excel中同名的sheet
2021/01/22 Python
Ellos丹麦:时尚和服装在线
2016/09/19 全球购物
党员的自我评价范文
2014/01/02 职场文书
安全生产汇报材料
2014/02/17 职场文书
《宿建德江》教学反思
2014/04/23 职场文书
庆祝教师节演讲稿
2014/09/03 职场文书
法英专业大学生职业生涯规划范文:衡外情,量己力!
2014/09/23 职场文书
研究生导师推荐信
2015/03/25 职场文书
网络营销实训总结
2015/08/03 职场文书
python数据分析之用sklearn预测糖尿病
2021/04/22 Python
MySQL 数据恢复的多种方法汇总
2021/06/21 MySQL
电频谱管理的原则是什么
2022/02/18 无线电