Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析python 中__name__ = '__main__' 的作用
Jul 05 Python
python爬虫获取京东手机图片的图文教程
Dec 29 Python
查看Django和flask版本的方法
May 14 Python
python实现根据指定字符截取对应的行的内容方法
Oct 23 Python
深入了解和应用Python 装饰器 @decorator
Apr 02 Python
python反转列表的三种方式解析
Nov 08 Python
kafka-python 获取topic lag值方式
Dec 23 Python
mac 上配置Pycharm连接远程服务器并实现使用远程服务器Python解释器的方法
Mar 19 Python
如何利用python正则表达式匹配版本信息
Dec 09 Python
详解Python中@staticmethod和@classmethod区别及使用示例代码
Dec 14 Python
Python中的面向接口编程示例详解
Jan 17 Python
Python3爬虫ChromeDriver的安装实例
Feb 06 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
让你同时上传 1000 个文件 (二)
2006/10/09 PHP
php mysql索引问题
2008/06/07 PHP
php 操作调试的方法
2012/07/12 PHP
php添加文章时生成静态HTML文章的实现代码
2013/02/17 PHP
php的hash算法介绍
2014/02/13 PHP
PHP mkdir()无写权限的问题解决方法
2014/06/19 PHP
PHP内核探索之变量
2015/12/22 PHP
使用PHP下载CSS文件中的所有图片【几行代码即可实现】
2016/12/14 PHP
Thinkphp5 微信公众号token验证不成功的原因及解决方法
2017/11/12 PHP
Javascript学习笔记8 用JSON做原型
2010/01/11 Javascript
基于jquery的Repeater实现代码
2010/07/17 Javascript
javascript 终止函数执行操作
2014/02/14 Javascript
javascript中的undefined和not defined区别示例介绍
2014/02/26 Javascript
JavaScript实现网页截图功能
2014/10/16 Javascript
javascript制作2048游戏
2015/03/30 Javascript
javascript下拉列表菜单的实现方法
2015/11/18 Javascript
Angular4表单验证代码详解
2017/09/03 Javascript
JavaScript实现的贝塞尔曲线算法简单示例
2018/01/30 Javascript
在vue中,v-for的索引index在html中的使用方法
2018/03/06 Javascript
基于ionic实现下拉刷新功能
2018/05/10 Javascript
微信小程序自定义可滑动日历界面
2018/12/28 Javascript
Node.js中出现未捕获异常的处理方法
2020/06/29 Javascript
[19:24]DOTA2客户端使用指南 一分钟快速设置轻松超神
2013/09/24 DOTA
Python实现删除Android工程中的冗余字符串
2015/01/19 Python
Python实现数据库并行读取和写入实例
2017/06/09 Python
Python实现加载及解析properties配置文件的方法
2018/03/29 Python
python3 assert 断言的使用详解 (区别于python2)
2019/11/27 Python
解决Python pip 自动更新升级失败的问题
2020/02/21 Python
pandas数据拼接的实现示例
2020/04/16 Python
VSCode配合pipenv搞定虚拟环境的实现方法
2020/05/17 Python
浅谈python出错时traceback的解读
2020/07/15 Python
Python实现壁纸下载与轮换
2020/10/19 Python
CSS中垂直居中的简单实现方法
2015/07/06 HTML / CSS
介绍一下Linux文件的记录形式
2012/04/18 面试题
幼儿教师师德师风自我剖析材料
2014/09/29 职场文书
学校党的群众路线教育实践活动总结材料
2014/10/30 职场文书