使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
下载给定网页上图片的方法
Feb 18 Python
使用PyInstaller将Python程序文件转换为可执行程序文件
Jul 08 Python
windows下python之mysqldb模块安装方法
Sep 07 Python
Python 字符串与二进制串的相互转换示例
Jul 23 Python
利用Python对文件夹下图片数据进行批量改名的代码实例
Feb 21 Python
python获取地震信息 微信实时推送
Jun 18 Python
解决win7操作系统Python3.7.1安装后启动提示缺少.dll文件问题
Jul 15 Python
Python Web框架之Django框架Form组件用法详解
Aug 16 Python
pytorch 自定义参数不更新方式
Jan 06 Python
python模拟实现斗地主发牌
Jan 07 Python
Python正则表达式学习小例子
Mar 03 Python
Python 如何查找特定类型文件
Aug 17 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
php IP及IP段进行访问限制的代码
2008/12/17 PHP
PHP连接MongoDB示例代码
2012/09/06 PHP
Symfony2获取web目录绝对路径、相对路径、网址的方法
2016/11/14 PHP
Javascript创建Silverlight Plugin以及自定义nonSilverlight和lowSilverlight样式
2010/06/28 Javascript
关于捕获用户何时点击window.onbeforeunload的取消事件
2011/03/06 Javascript
js随机颜色代码的多种实现方式
2013/04/23 Javascript
JavaScript使用slice函数获取数组部分元素的方法
2015/04/06 Javascript
javascript中 try catch用法
2015/08/16 Javascript
学习JavaScript设计模式之模板方法模式
2016/01/20 Javascript
JS 动态加载js文件和css文件 同步/异步的两种简单方式
2016/09/23 Javascript
解决vue中修改了数据但视图无法更新的情况
2018/08/27 Javascript
Angular6使用forRoot() 注册单一实例服务问题
2019/08/27 Javascript
解决vue项目刷新后,导航菜单高亮显示的位置不对问题
2019/11/01 Javascript
JavaScript使用prototype属性实现继承操作示例
2020/05/22 Javascript
[03:38]2014DOTA2西雅图国际邀请赛 VG战队巡礼
2014/07/07 DOTA
[48:21]Mski vs VGJ.S Supermajor小组赛C组 BO3 第一场 6.3
2018/06/04 DOTA
python实现稀疏矩阵示例代码
2017/06/09 Python
关于Python中浮点数精度处理的技巧总结
2017/08/10 Python
Python语言实现百度语音识别API的使用实例
2017/12/13 Python
Python 元类实例解析
2018/04/04 Python
python爬虫实例详解
2018/06/19 Python
一百行python代码将图片转成字符画
2021/02/19 Python
python去除拼音声调字母,替换为字母的方法
2018/11/28 Python
Django的用户模块与权限系统的示例代码
2019/07/24 Python
Python 多线程共享变量的实现示例
2020/04/17 Python
HTML5 Canvas如何实现纹理填充与描边(Fill And Stroke)
2013/07/15 HTML / CSS
移动端html5判断是否滚动到底部并且下拉加载
2019/11/19 HTML / CSS
Hotels.com越南:酒店预订
2019/10/29 全球购物
计算机专业推荐信范文
2013/11/20 职场文书
我的求职计划书
2014/01/10 职场文书
党支部班子“四风”问题自我剖析材料
2014/09/28 职场文书
2015年七一建党节慰问信
2015/03/23 职场文书
联谊会开场白
2015/06/01 职场文书
导游词之江西赣州
2019/10/15 职场文书
详解Golang如何优雅的终止一个服务
2022/03/21 Golang
《勇者辞职不干了》上卷BD发售宣传CM公开
2022/04/08 日漫