Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python while、for、生成器、列表推导等语句的执行效率测试
Jun 03 Python
Python编程中字符串和列表的基本知识讲解
Oct 14 Python
小议Python中自定义函数的可变参数的使用及注意点
Jun 21 Python
python 容器总结整理
Apr 04 Python
Python2.7 实现引入自己写的类方法
Apr 29 Python
Python实现的json文件读取及中文乱码显示问题解决方法
Aug 06 Python
python进行TCP端口扫描的实现
Dec 21 Python
Django Channel实时推送与聊天的示例代码
Apr 30 Python
Python flask框架端口失效解决方案
Jun 04 Python
django美化后台django-suit的安装配置操作
Jul 12 Python
Pandas中DataFrame交换列顺序的方法实现
Dec 14 Python
Python机器学习之逻辑回归
May 11 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
用 PHP5 轻松解析 XML
2006/12/04 PHP
PHP写MySQL数据 实现代码
2009/06/15 PHP
thinkphp实现多语言功能(语言包)
2014/03/04 PHP
PHP levenshtein()函数用法讲解
2019/03/08 PHP
7个Javascript地图脚本整理
2009/10/20 Javascript
jquery+ajax+C#实现无刷新操作数据库数据的简单实例
2014/02/08 Javascript
jquery实现简单的banner轮播效果【实例】
2016/03/30 Javascript
在Node.js中使用Javascript Generators详解
2016/05/05 Javascript
jquery表单插件form使用方法详解
2017/01/20 Javascript
利用JavaScript如何查询某个值是否数组内
2017/07/30 Javascript
对于input 框限定输入值为浮点型的js代码
2017/09/25 Javascript
深入理解Puppeteer的入门教程和实践
2019/03/05 Javascript
微信小程序实现用table显示数据库反馈的多条数据功能示例
2019/05/07 Javascript
vue 表单之通过v-model绑定单选按钮radio
2019/05/13 Javascript
Django+Vue实现WebSocket连接的示例代码
2019/05/28 Javascript
ES6 Set结构的应用实例分析
2019/06/26 Javascript
Node.js fs模块(文件模块)创建、删除目录(文件)读取写入文件流的方法
2019/09/03 Javascript
Python警察与小偷的实现之一客户端与服务端通信实例
2014/10/09 Python
在Python的Django框架中包装视图函数
2015/07/20 Python
Django与遗留的数据库整合的方法指南
2015/07/24 Python
举例讲解Python中的Null模式与桥接模式编程
2016/02/02 Python
遗传算法之Python实现代码
2017/10/10 Python
解决Python安装后pip不能用的问题
2018/06/12 Python
python Matplotlib底图中鼠标滑过显示隐藏内容的实例代码
2019/07/31 Python
对Django中的权限和分组管理实例讲解
2019/08/16 Python
浅谈HTML5 服务器推送事件(Server-sent Events)
2017/08/01 HTML / CSS
html5使用Drag事件编辑器拖拽上传图片的示例代码
2017/08/22 HTML / CSS
德国消费电子产品购物网站:Guter Kauf
2020/09/15 全球购物
生产部厂长职位说明书
2014/03/03 职场文书
《荷花》教学反思
2014/04/16 职场文书
教研处工作方案
2014/05/26 职场文书
出国留学导师推荐信
2015/03/26 职场文书
经费申请报告范文
2015/05/18 职场文书
2015年卫生局工作总结
2015/07/24 职场文书
2019年健身俱乐部的创业计划书
2019/08/26 职场文书
python 实现定时任务的四种方式
2021/04/01 Python