使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的迭代器与生成器高级用法解析
Jun 28 Python
浅谈python中的变量默认是什么类型
Sep 11 Python
python使用pycharm环境调用opencv库
Feb 11 Python
python读取文件名称生成list的方法
Apr 27 Python
Python查找文件中包含中文的行方法
Dec 19 Python
Python实现的微信支付方式总结【三种方式】
Apr 13 Python
使用Python检测文章抄袭及去重算法原理解析
Jun 14 Python
Python实现TCP探测目标服务路由轨迹的原理与方法详解
Sep 04 Python
PyTorch加载自己的数据集实例详解
Mar 18 Python
Python文本文件的合并操作方法代码实例
Mar 31 Python
Python sqlalchemy时间戳及密码管理实现代码详解
Aug 01 Python
Django集成MongoDB实现过程解析
Dec 01 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
PHP框架Laravel的小技巧两则
2015/02/10 PHP
服务器迁移php版本不同可能诱发的问题
2015/12/22 PHP
Yii框架连表查询操作示例
2019/09/06 PHP
Javascript学习笔记4 Eval函数
2010/01/11 Javascript
javascript学习(二)javascript常见问题总结
2013/01/02 Javascript
高效率JavaScript编写技巧整理
2013/08/23 Javascript
js+div实现文字滚动和图片切换效果代码
2015/08/27 Javascript
jQuery ztree实现动态树形多选菜单
2016/08/12 Javascript
js轮播图透明度切换(带上下页和底部圆点切换)
2017/04/27 Javascript
node.js操作mysql简单实例
2017/05/25 Javascript
vue复合组件实现注册表单功能
2017/11/06 Javascript
详解Vue-cli中的静态资源管理(src/assets和static/的区别)
2018/06/19 Javascript
详解VUE单页应用骨架屏方案
2019/01/17 Javascript
vue+ts下对axios的封装实现
2020/02/18 Javascript
JavaScript中继承原理与用法实例入门
2020/05/09 Javascript
[01:15:44]首部DOTA2纪录片今日23时全网上映
2014/03/19 DOTA
python控制台英汉汉英电子词典
2020/04/23 Python
Python 判断是否为质数或素数的实例
2017/10/30 Python
python 脚本生成随机 字母 + 数字密码功能
2018/05/26 Python
Python中zip()函数的简单用法举例
2019/09/02 Python
python基于三阶贝塞尔曲线的数据平滑算法
2019/12/27 Python
Python递归函数特点及原理解析
2020/03/04 Python
Python SQLAlchemy库的使用方法
2020/10/13 Python
基于HTML5实现类似微信手机摇一摇功能(计算摇动次数)
2017/07/24 HTML / CSS
英国家居装饰品、户外家具和玻璃器皿购物网站:Rinkit.com
2019/11/04 全球购物
英国领先的独立酒精饮料零售商:DrinkSupermarket
2021/01/13 全球购物
衰败城市英国官网:Urban Decay英国
2020/04/29 全球购物
加拿大户外探险购物网站:SAIL
2020/06/27 全球购物
建筑投标担保书
2014/05/20 职场文书
2014年大学生社会实践自我鉴定
2014/09/26 职场文书
消防隐患整改通知书
2015/04/22 职场文书
化妆品促销活动总结
2015/05/07 职场文书
爱心捐助活动总结
2015/05/09 职场文书
运动会新闻稿
2015/07/17 职场文书
实验心得体会范文
2016/01/25 职场文书
SpringCloud Alibaba 基本开发框架搭建过程
2021/06/13 Java/Android