使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中非常实用的一些功能和函数分享
Feb 14 Python
python打开文件并获取文件相关属性的方法
Apr 23 Python
python正常时间和unix时间戳相互转换的方法
Apr 23 Python
在Python中使用next()方法操作文件的教程
May 24 Python
python中计算一个列表中连续相同的元素个数方法
Jun 29 Python
python 实现倒排索引的方法
Dec 25 Python
Django框架实现的分页demo示例
May 25 Python
浅谈matplotlib.pyplot与axes的关系
Mar 06 Python
QML实现钟表效果
Jun 02 Python
Python如何将字符串转换为日期
Jul 31 Python
python3.7中安装paddleocr及paddlepaddle包的多种方法
Nov 27 Python
Python实现对齐打印 format函数的用法
Apr 28 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
php版淘宝网查询商品接口代码示例
2014/06/17 PHP
PHP实现图片的等比缩放和Logo水印功能示例
2017/05/04 PHP
IE下js调试工具Companion.JS
2010/10/15 Javascript
基于jquery的图片幻灯展示源码
2012/07/15 Javascript
调用HttpHanlder的几种返回方式小结
2013/12/20 Javascript
多种方法实现load加载完成后把图片一次性显示出来
2014/02/19 Javascript
js创建表单元素并使用submit进行提交
2014/08/14 Javascript
js设置控件的隐藏与显示的两种方法
2014/08/21 Javascript
浅谈被jQuery抛弃的函数及替代函数
2015/05/03 Javascript
uploadify多文件上传参数设置技巧
2015/11/16 Javascript
Bootstrap网格系统详解
2016/04/26 Javascript
TinyMCE汉化及本地上传图片功能实例详解
2016/05/31 Javascript
学习vue.js条件渲染
2016/12/03 Javascript
AngularJS Select(选择框)使用详解
2017/01/18 Javascript
VUE元素的隐藏和显示(v-show指令)
2017/06/23 Javascript
BootStrap模态框不垂直居中的解决方法
2017/10/19 Javascript
jQuery实现模糊搜索功能的方法分析
2018/06/29 jQuery
PWA介绍及快速上手搭建一个PWA应用的方法
2019/01/27 Javascript
[01:35]2014DOTA2西雅图邀请赛 专访狐狸妈青春献给刀塔
2014/07/08 DOTA
[01:54]TI4西雅图DOTA2选手欢迎晚宴 现场报道
2014/07/08 DOTA
[53:43]VP vs NewBee Supermajor 胜者组 BO3 第三场 6.5
2018/06/06 DOTA
Python Tkinter简单布局实例教程
2014/09/03 Python
Python微信库:itchat的用法详解
2017/08/14 Python
Python使用pip安装报错:is not a supported wheel on this platform的解决方法
2018/01/23 Python
Django教程笔记之中间件middleware详解
2018/08/01 Python
Python3 安装PyQt5及exe打包图文教程
2019/01/08 Python
Python实现读取txt文件中的数据并绘制出图形操作示例
2019/02/26 Python
flask框架中的cookie和session使用
2021/01/31 Python
澳大利亚UGG工厂直销:Australian Ugg Boots
2017/10/14 全球购物
丝绸和人造花卉、植物和树木:Nearly Natural
2018/11/28 全球购物
应聘教师求职信
2014/07/19 职场文书
教师党员群众路线教育实践活动心得体会
2014/11/04 职场文书
2014年办公室文员工作总结
2014/11/12 职场文书
反邪教警示教育活动总结
2015/05/09 职场文书
《我的长生果》教学反思
2016/02/20 职场文书
Windows11 Insider Preview Build 25206今日发布 更新内容汇总
2022/09/23 数码科技