使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python strip lstrip rstrip使用方法
Sep 06 Python
pyqt4教程之messagebox使用示例分享
Mar 07 Python
学习python类方法与对象方法
Mar 15 Python
python 数据清洗之数据合并、转换、过滤、排序
Feb 12 Python
python实现字典(dict)和字符串(string)的相互转换方法
Mar 01 Python
PyQt5每天必学之拖放事件
Aug 27 Python
Python中的pathlib.Path为什么不继承str详解
Jun 23 Python
python读取.mat文件的数据及实例代码
Jul 12 Python
python-tornado的接口用swagger进行包装的实例
Aug 29 Python
执行Python程序时模块报错问题
Mar 26 Python
Python pytesseract验证码识别库用法解析
Jun 29 Python
Python打印不合法的文件名
Jul 31 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
PHP最常用的ini函数分析 针对PHP.ini配置文件
2010/04/22 PHP
PHP利用REFERER根居访问来地址进行页面跳转
2013/09/28 PHP
解密ThinkPHP3.1.2版本之独立分组功能应用
2014/06/19 PHP
PHP生成短网址的3种方法代码实例
2014/07/08 PHP
PHP实现递归复制整个文件夹的类实例
2015/08/03 PHP
PHP中file_put_contents追加和换行的实现方法
2017/04/01 PHP
iis 7下安装laravel 5.4环境的方法教程
2017/06/14 PHP
两个SUBMIT按钮,如何区分处理
2006/08/22 Javascript
Javascript和Ajax中文乱码吐血版解决方案
2009/12/21 Javascript
通过jQuery源码学习javascript(一)
2012/12/27 Javascript
JavaScript中的原型和继承详解(图文)
2014/07/18 Javascript
jQuery实现tag便签去重效果的方法
2015/01/20 Javascript
jQuery插件Elastislide实现响应式的焦点图无缝滚动切换特效
2015/04/12 Javascript
jquery+css3实现网页背景花瓣随机飘落特效
2015/08/17 Javascript
javascript日期格式化方法汇总
2015/10/04 Javascript
jQuery中通过ajax的get()函数读取页面的方法
2016/02/29 Javascript
jquery 无限极下拉菜单的简单实例(精简浓缩版)
2016/05/31 Javascript
js 提交form表单和设置form表单请求路径的实现方法
2016/10/25 Javascript
在js里怎么实现Xcode里的callFuncN方法(详解)
2016/11/05 Javascript
jquery 实时监听输入框值变化的完美方法(必看)
2017/01/26 Javascript
详解JavaScript对象的深浅复制
2017/03/30 Javascript
js学习总结之DOM2兼容处理重复问题的解决方法
2017/07/27 Javascript
为vue-router懒加载时下载js的过程中添加loading提示避免无响应问题
2018/04/03 Javascript
vue自定义指令之面板拖拽的实现
2019/04/14 Javascript
vue中v-show和v-if的异同及v-show用法
2019/06/06 Javascript
如何在vue中使用kindeditor富文本编辑器
2020/12/19 Vue.js
解决uWSGI的编码问题详解
2017/03/24 Python
Ubuntu 下 vim 搭建python 环境 配置
2017/06/12 Python
pip install 使用国内镜像的方法示例
2020/04/03 Python
Python定时从Mysql提取数据存入Redis的实现
2020/05/03 Python
Python新手学习装饰器
2020/06/04 Python
俄罗斯购买剧院和演唱会门票网站:Parter.ru
2019/11/09 全球购物
2014物价局民主生活会对照检查材料思想汇报
2014/09/24 职场文书
对党的十八届四中全会的期盼
2014/10/17 职场文书
使用Redis实现分布式锁的方法
2022/06/16 Redis
Python使用pandas导入csv文件内容的示例代码
2022/12/24 Python