使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python命令行参数sys.argv使用示例
Jan 28 Python
Python3基础之函数用法
Aug 13 Python
python提取页面内url列表的方法
May 25 Python
利用PyInstaller将python程序.py转为.exe的方法详解
May 03 Python
Python2随机数列生成器简单实例
Sep 04 Python
django项目搭建与Session使用详解
Oct 10 Python
Python实现Event回调机制的方法
Feb 13 Python
对Python中class和instance以及self的用法详解
Jun 26 Python
Python使用uuid库生成唯一标识ID
Feb 12 Python
基于python实现微信好友数据分析(简单)
Feb 16 Python
Django ORM判断查询结果是否为空,判断django中的orm为空实例
Jul 09 Python
最新pycharm安装教程
Nov 18 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
在WAMP环境下搭建ZendDebugger php调试工具的方法
2011/07/18 PHP
php数组函数序列之sort() 对数组的元素值进行升序排序
2011/11/02 PHP
PHP数组游标实现对数组的各种操作详解
2016/01/26 PHP
centos下file_put_contents()无法写入文件的原因及解决方法
2017/04/01 PHP
PHP封装的简单连接MongoDB类示例
2019/02/13 PHP
海量经典的jQuery插件集合
2010/01/12 Javascript
JavaScript 滚轮事件使用说明
2010/03/07 Javascript
JS定时刷新页面及跳转页面的方法
2013/07/04 Javascript
动态创建script在IE中缓存js文件时导致编码的解决方法
2014/05/04 Javascript
Bootstrap轮播图的使用和理解4
2016/12/14 Javascript
详解使用webpack打包编写一个vue-toast插件
2017/11/08 Javascript
在vue-cli搭建的项目中增加后台mock接口的方法
2018/04/26 Javascript
vue项目中,main.js,App.vue,index.html的调用方法
2018/09/20 Javascript
vue 右键菜单插件 简单、可扩展、样式自定义的右键菜单
2018/11/29 Javascript
实例讲解JavaScript截取字符串
2018/11/30 Javascript
通过说明与示例了解js五种设计模式
2019/06/17 Javascript
微信小程序防止多次点击跳转(函数节流)
2019/09/19 Javascript
python logging 日志轮转文件不删除问题的解决方法
2016/08/02 Python
由浅入深讲解python中的yield与generator
2017/04/05 Python
Python实现图片转字符画的示例代码
2017/08/21 Python
python实现人脸识别经典算法(一) 特征脸法
2018/03/13 Python
Flask入门之上传文件到服务器的方法示例
2018/07/18 Python
Python爬取数据保存为Json格式的代码示例
2019/04/09 Python
解决Python找不到ssl模块问题 No module named _ssl的方法
2019/04/29 Python
Pycharm如何打断点的方法步骤
2019/06/13 Python
深入了解Django中间件及其方法
2019/07/26 Python
Python socket 套接字实现通信详解
2019/08/27 Python
python能否java成为主流语言吗
2020/06/22 Python
为2021年的第一场雪锦上添花:用matplotlib绘制雪花和雪景
2021/01/05 Python
使用css3制作登录表单的步骤
2014/04/07 HTML / CSS
苹果香港官方商城:Apple香港
2016/09/14 全球购物
韩国保养品、日本药妆购物网:小三美日
2018/12/30 全球购物
卫生院健康教育实施方案
2014/06/07 职场文书
银行自荐信怎么写
2015/03/05 职场文书
浪漫婚礼主持词开场白
2015/11/24 职场文书
golang中的struct操作
2021/11/11 Golang