Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程编程(七):使用Condition实现复杂同步
Apr 05 Python
Python制作爬虫抓取美女图
Jan 20 Python
基于python yield机制的异步操作同步化编程模型
Mar 18 Python
Python读取Json字典写入Excel表格的方法
Jan 03 Python
Python实现Kmeans聚类算法
Jun 10 Python
使用Python写一个量化股票提醒系统
Aug 22 Python
Python实用工具FuckIt.py介绍
Jul 02 Python
对python中基于tcp协议的通信(数据传输)实例讲解
Jul 22 Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 Python
春节到了 教你使用python来抢票回家
Jan 06 Python
Python实现敏感词过滤的4种方法
Sep 12 Python
记一次Django响应超慢的解决过程
Sep 17 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
在mysql数据库原有字段后增加新内容
2009/11/26 PHP
php防止网站被攻击的应急代码
2015/10/21 PHP
javascript 通用简单的table选项卡实现
2010/05/07 Javascript
JavaScript基本语法讲解
2015/06/03 Javascript
jQuery+HTML5实现手机摇一摇换衣特效
2015/06/05 Javascript
纯javascript实现四方向文本无缝滚动效果
2015/06/16 Javascript
vue2.0在没有dev-server.js下的本地数据配置方法
2018/02/23 Javascript
vue.js使用3DES加密的方法示例
2018/05/18 Javascript
详解Vue2.0组件的继承与扩展
2018/11/23 Javascript
基于Vue SEO的四种方案(小结)
2019/07/01 Javascript
js脚本中执行java后台代码方法解析
2019/10/11 Javascript
微信小程序 this.triggerEvent()的具体使用
2019/12/10 Javascript
js 将多个对象合并成一个对象 assign方法的实现
2020/09/24 Javascript
[57:55]完美世界DOTA2联赛PWL S3 Magma vs Phoenix 第二场 12.12
2020/12/16 DOTA
Python使用面向对象方式创建线程实现12306售票系统
2015/12/24 Python
Python 中的 else详解
2016/04/23 Python
Python线程同步的实现代码
2018/10/03 Python
python实现扫描ip地址的小程序
2019/04/16 Python
微信公众号token验证失败解决方案
2019/07/22 Python
python 实现控制鼠标键盘
2020/11/27 Python
世界上最大的餐具公司:Oneida
2016/12/17 全球购物
Brother加拿大官网:打印机、贴标机、缝纫机
2019/10/09 全球购物
在职人员函授期间自我评价分享
2013/11/08 职场文书
简单而又朴实的个人求职信分享
2013/12/12 职场文书
七一表彰活动方案
2014/01/18 职场文书
小学生打架检讨书
2014/01/26 职场文书
护士岗位职责
2014/02/16 职场文书
公司年会搞笑主持词
2014/03/24 职场文书
养牛场项目建议书
2014/05/13 职场文书
保护环境建议书100字
2014/05/13 职场文书
设备售后服务承诺书
2014/05/30 职场文书
植物生产学专业求职信
2014/08/08 职场文书
五好文明家庭事迹材料
2014/12/20 职场文书
趣味运动会赞词
2015/07/22 职场文书
MySQL表锁、行锁、排它锁及共享锁的使用详解
2022/04/02 MySQL
AndroidStudio图片压缩工具ImgCompressPlugin使用实例
2022/08/05 Java/Android