使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python中PDB模块中的命令来调试Python代码的教程
Mar 30 Python
整理Python最基本的操作字典的方法
Apr 24 Python
Python中正则表达式详解
May 17 Python
Win7下Python与Tensorflow-CPU版开发环境的安装与配置过程
Jan 04 Python
python中for用来遍历range函数的方法
Jun 08 Python
django配置连接数据库及原生sql语句的使用方法
Mar 03 Python
python3利用Socket实现通信的方法示例
May 06 Python
PyCharm导入python项目并配置虚拟环境的教程详解
Oct 13 Python
python函数不定长参数使用方法解析
Dec 14 Python
基于Python和C++实现删除链表的节点
Jul 06 Python
python3.9实现pyinstaller打包python文件成exe
Dec 13 Python
python某漫画app逆向
Mar 31 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
php分页函数
2006/07/08 PHP
php下使用curl模拟用户登陆的代码
2010/09/10 PHP
深入for,while,foreach遍历时间比较的详解
2013/06/08 PHP
jQuery插件-jRating评分插件源码分析及使用方法
2012/12/28 Javascript
JS获取html对象的几种方式介绍
2013/12/05 Javascript
js加密解密字符串可自定义密码因子
2014/05/13 Javascript
你所未知的3种Node.js代码优化方式
2016/02/25 Javascript
jQuery动态添加
2016/04/07 Javascript
微信小程序 网络请求(GET请求)详解
2016/11/16 Javascript
全面解析vue中的数据双向绑定
2017/05/10 Javascript
jQuery实现打开网页自动弹出遮罩层或点击弹出遮罩层功能示例
2017/10/19 jQuery
微信小程序多音频播放进度条问题
2018/08/28 Javascript
Vue.set() this.$set()引发的视图更新思考及注意事项
2018/08/30 Javascript
vue中 数字相加为字串转化为数值的例子
2019/11/07 Javascript
Preload基础使用方法详解
2020/02/03 Javascript
解决echarts 一条柱状图显示两个值,类似进度条的问题
2020/07/20 Javascript
Python字符转换
2008/09/06 Python
Python读写Excel文件方法介绍
2014/11/22 Python
Python中Django框架利用url来控制登录的方法
2015/07/25 Python
利用Python批量压缩png方法实例(支持过滤个别文件与文件夹)
2017/07/30 Python
python cx_Oracle的基础使用方法(连接和增删改查)
2017/11/19 Python
python处理DICOM并计算三维模型体积
2019/02/26 Python
PyQt5创建一个新窗口的实例
2019/06/20 Python
详解Python 多线程 Timer定时器/延迟执行、Event事件
2019/06/27 Python
在PyCharm中实现添加快捷模块
2020/02/12 Python
python操作docx写入内容,并控制文本的字体颜色
2020/02/13 Python
全方位了解CSS3的Regions扩展
2015/08/07 HTML / CSS
CSS3实现歌词进度文字颜色填充变化动态效果的思路详解
2020/06/02 HTML / CSS
CSS 说明横向进度条最后显示文字的实现代码
2020/11/10 HTML / CSS
世界上最好的野生海鲜和有机食品:Vital Choice
2020/01/16 全球购物
个人四风问题对照检查材料思想汇报
2014/10/06 职场文书
2014年财政所工作总结
2014/11/22 职场文书
龙猫观后感
2015/06/09 职场文书
小学家庭教育心得体会
2016/01/14 职场文书
Python数据处理的三个实用技巧分享
2022/04/01 Python
使用Nginx的访问日志统计PV与UV
2022/05/06 Servers