使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python对list列表结构中的值进行去重的方法总结
May 07 Python
Python内置数据结构与操作符的练习题集锦
Jul 01 Python
python使用fork实现守护进程的方法
Nov 16 Python
Python面向对象编程之继承与多态详解
Jan 16 Python
python链接oracle数据库以及数据库的增删改查实例
Jan 30 Python
python matlibplot绘制3D图形
Jul 02 Python
对python 数据处理中的LabelEncoder 和 OneHotEncoder详解
Jul 11 Python
Python线程下使用锁的技巧分享
Sep 13 Python
Python字节单位转换实例
Dec 05 Python
详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程
Nov 02 Python
python用海龟绘图写贪吃蛇游戏
Jun 18 Python
详解python的异常捕获
Mar 03 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
Zend Studio (eclipse)使用速度优化方法
2011/03/23 PHP
ThinkPHP的Widget扩展实例
2014/06/19 PHP
详解提高使用Java反射的效率方法
2019/04/29 PHP
JavaScript进阶教程(第四课第一部分)
2007/04/05 Javascript
用javascript实现计算两个日期的间隔天数
2007/08/14 Javascript
优化 JavaScript 代码的方法小结
2009/07/16 Javascript
jQuery不兼容input的change事件问题解决过程
2014/12/05 Javascript
60个很实用的jQuery代码开发技巧收集
2014/12/15 Javascript
浅谈JavaScript中的String对象常用方法
2015/02/25 Javascript
简介JavaScript中substring()方法的使用
2015/06/06 Javascript
jquery实现文本框textarea自适应高度
2016/03/09 Javascript
解决拦截器对ajax请求的拦截实例详解
2016/12/21 Javascript
详谈jQuery unbind 删除绑定事件 / 移除标签方法
2017/03/02 Javascript
vue与TypeScript集成配置最简教程(推荐)
2017/10/17 Javascript
微信小程序loading组件显示载入动画用法示例【附源码下载】
2017/12/09 Javascript
angularjs实现table增加tr的方法
2018/02/27 Javascript
Python获取网页上图片下载地址的方法
2015/03/11 Python
利用python批量给云主机配置安全组的方法教程
2017/06/21 Python
浅谈Pycharm中的Python Console与Terminal
2019/01/17 Python
使用Python向DataFrame中指定位置添加一列或多列的方法
2019/01/29 Python
详解Python匿名函数(lambda函数)
2019/04/19 Python
在PyCharm中实现添加快捷模块
2020/02/12 Python
基于Python共轭梯度法与最速下降法之间的对比
2020/04/02 Python
CSS3 优势以及网页设计师如何使用CSS3技术
2009/07/29 HTML / CSS
详解CSS3中常用的样式【基本文本和字体样式】
2020/10/20 HTML / CSS
德国汉莎航空中国官网: Lufthansa中国
2017/03/30 全球购物
荷兰男士时尚网上商店:Suitable
2017/12/25 全球购物
银行介绍信范文
2014/01/10 职场文书
宣传策划类求职信范文
2014/01/31 职场文书
小学教师个人先进事迹材料
2014/05/17 职场文书
2014年教师节演讲稿
2014/09/03 职场文书
乡镇领导班子四风整顿行动工作汇报
2014/10/25 职场文书
2015年元旦标语大全
2014/12/09 职场文书
工伤认定行政答辩状
2015/05/22 职场文书
CSS3中Animation实现简单的手指点击动画的示例
2021/07/15 HTML / CSS
关于springboot 配置date字段返回时间戳的问题
2021/07/25 Java/Android