使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
django DRF图片路径问题的解决方法
Sep 10 Python
pandas 数据归一化以及行删除例程的方法
Nov 10 Python
Python跳出多重循环的方法示例
Jul 03 Python
从列表或字典创建Pandas的DataFrame对象的方法
Jul 06 Python
详解Python对JSON中的特殊类型进行Encoder
Jul 15 Python
python3中替换python2中cmp函数的实现
Aug 20 Python
python或C++读取指定文件夹下的所有图片
Aug 31 Python
python如何将两张图片生成为全景图片
Mar 05 Python
Python接口开发实现步骤详解
Apr 26 Python
scrapy与selenium结合爬取数据(爬取动态网站)的示例代码
Sep 28 Python
Pytorch 统计模型参数量的操作 param.numel()
May 13 Python
Django实现聊天机器人
May 31 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
实用函数7
2007/11/08 PHP
PHP获取http请求的头信息实现步骤
2012/12/16 PHP
php实现仿写CodeIgniter的购物车类
2015/07/29 PHP
使用PHP接受文件并获得其后缀名的方法
2015/08/05 PHP
php获取给定日期相差天数的方法分析
2017/02/20 PHP
php扩展开发入门demo示例
2019/09/23 PHP
javascript 写类方式之九
2009/07/05 Javascript
js跨域和ajax 跨域问题的实现思路
2009/09/05 Javascript
input按钮的事件处理大全
2010/12/10 Javascript
js自动查找select下拉的菜单并选择(示例代码)
2014/02/26 Javascript
iframe实用操作锦集
2014/04/22 Javascript
JavaScript中的无阻塞加载性能优化方案
2014/10/10 Javascript
轻松创建nodejs服务器(3):代码模块化
2014/12/18 NodeJs
jquery 中ajax执行的优先级
2015/06/22 Javascript
好好了解一下Cookie(强烈推荐)
2016/06/14 Javascript
js实现加载更多功能实例
2016/10/27 Javascript
巧用数组制作图片切换js代码
2016/11/29 Javascript
Angular.js中window.onload(),$(document).ready()的写法浅析
2017/09/28 Javascript
js实现动态增加文件域表单功能
2018/10/22 Javascript
微信小程序自定义tabBar组件开发详解
2020/09/24 Javascript
vue实现微信获取用户信息的方法
2019/03/21 Javascript
vue实现信息管理系统
2020/05/30 Javascript
python Django模板的使用方法(图文)
2013/11/04 Python
Python调用adb命令实现对多台设备同时进行reboot的方法
2018/10/15 Python
python基于TCP实现的文件下载器功能案例
2019/12/10 Python
Tensorflow实现部分参数梯度更新操作
2020/01/23 Python
在Python中通过threshold创建mask方式
2020/02/19 Python
从0到1使用python开发一个半自动答题小程序的实现
2020/05/12 Python
财务会计专业求职信范文
2013/12/31 职场文书
网络程序员自荐信
2014/01/25 职场文书
护理专业优质毕业生自荐书
2014/01/31 职场文书
历史专业学生的自我评价
2014/02/28 职场文书
工作态度不端正检讨书
2014/10/04 职场文书
礼貌问候语大全
2015/11/10 职场文书
Python利用Turtle绘制哆啦A梦和小猪佩奇
2022/04/04 Python
《群青的幻想曲》京力秋树角色PV公开
2022/04/08 日漫