Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用PYTHON接收多播数据的代码
Mar 01 Python
python 获取list特定元素下标的实例讲解
Apr 09 Python
python执行系统命令后获取返回值的几种方式集合
May 12 Python
Python 查看list中是否含有某元素的方法
Jun 27 Python
使用Python实现一个栈判断括号是否平衡
Aug 23 Python
Python3.6简单的操作Mysql数据库的三个实例
Oct 17 Python
python绘制简单彩虹图
Nov 19 Python
解决Python2.7中IDLE启动没有反应的问题
Nov 30 Python
啥是佩奇?使用Python自动绘画小猪佩奇的代码实例
Feb 20 Python
python脚本后台执行方式
Dec 21 Python
在Pytorch中计算卷积方法的区别详解(conv2d的区别)
Jan 03 Python
Django数据模型中on_delete使用详解
Nov 30 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
咖啡机如何保养和日常清洁?
2021/03/03 冲泡冲煮
ThinkPHP3.1新特性之对分组支持的改进与完善概述
2014/06/19 PHP
ThinkPHP在新浪SAE平台的部署实例
2014/10/31 PHP
php禁止直接从浏览器输入地址访问.php文件的方法
2014/11/04 PHP
php中addslashes函数与sql防注入
2014/11/17 PHP
PHP面向对象程序设计模拟一般面向对象语言中的方法重载(overload)示例
2019/06/13 PHP
JQUERY复选框CHECKBOX全选,取消全选
2008/08/30 Javascript
JavaScript 组件之旅(四):测试 JavaScript 组件
2009/10/28 Javascript
Node.js实现批量去除BOM文件头
2014/12/20 Javascript
javascript正则表达式之分组概念与用法实例
2016/06/16 Javascript
JS实现图片垂直居中显示小结
2016/12/13 Javascript
bootstrap滚动监控器使用方法解析
2017/01/13 Javascript
JavaScript字符串对象
2017/01/14 Javascript
JavaScript实现无穷滚动加载数据
2017/05/06 Javascript
Vue+Flask实现简单的登录验证跳转的示例代码
2018/01/13 Javascript
详解基于DllPlugin和DllReferencePlugin的webpack构建优化
2018/06/28 Javascript
对vue 键盘回车事件的实例讲解
2018/08/25 Javascript
js实现超级玛丽小游戏
2020/03/18 Javascript
vue实现移动端拖动排序
2020/08/21 Javascript
在Python中使用AOP实现Redis缓存示例
2017/07/11 Python
python利用requests库进行接口测试的方法详解
2018/07/06 Python
Python3.5 Json与pickle实现数据序列化与反序列化操作示例
2019/04/29 Python
如何用Python来搭建一个简单的推荐系统
2019/08/07 Python
Python scipy的二维图像卷积运算与图像模糊处理操作示例
2019/09/06 Python
Python中pyecharts安装及安装失败的解决方法
2020/02/18 Python
python二维图制作的实例代码
2020/12/03 Python
HTML5实现的震撼3D焦点图动画的示例代码
2019/09/26 HTML / CSS
html5.2 dialog简介详解
2018/02/27 HTML / CSS
什么造成了Java里面的异常
2016/04/24 面试题
学生干部学习的自我评价
2014/02/18 职场文书
《老山界》教学反思
2014/04/08 职场文书
2014党员整改措施思想汇报
2014/10/07 职场文书
商业计划书格式、范文
2019/03/21 职场文书
怎样评估创业计划书是否有可行性?
2019/08/07 职场文书
七年级作文之下雨天
2019/12/23 职场文书
MySQL慢查询优化解决问题
2022/03/17 MySQL