Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用matplotlib实现在坐标系中画一个矩形的方法
May 20 Python
讲解Python的Scrapy爬虫框架使用代理进行采集的方法
Feb 18 Python
编写Python小程序来统计测试脚本的关键字
Mar 12 Python
Python基于tkinter模块实现的改名小工具示例
Jul 27 Python
python实现守护进程、守护线程、守护非守护并行
May 05 Python
python分批定量读取文件内容,输出到不同文件中的方法
Dec 08 Python
关于python中密码加盐的学习体会小结
Jul 15 Python
python opencv将图片转为灰度图的方法示例
Jul 31 Python
pycharm部署、配置anaconda环境的教程
Mar 24 Python
详解pandas绘制矩阵散点图(scatter_matrix)的方法
Apr 23 Python
python中常见错误及解决方法
Jun 21 Python
Python+OpenCV图像处理——实现直线检测
Oct 23 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
截获网站title标签之家内容的例子
2006/10/09 PHP
笑谈配置,使用Smarty技术
2007/01/04 PHP
IStream与TStream之间的相互转换
2008/08/01 PHP
php 中英文语言转换类代码
2011/08/11 PHP
PHP函数spl_autoload_register()用法和__autoload()介绍
2012/02/04 PHP
CodeIgniter上传图片成功的全部过程分享
2013/08/12 PHP
php中的动态调用实例分析
2015/01/07 PHP
PHP简单选择排序算法实例
2015/01/26 PHP
一个简单安全的PHP验证码类、PHP验证码
2016/09/24 PHP
php上传excel表格并获取数据
2017/04/27 PHP
Laravel构建即时应用的一种实现方法详解
2017/08/31 PHP
PHP实现创建一个RPC服务操作示例
2020/02/23 PHP
js实现兼容IE6与IE7的DIV高度
2010/05/13 Javascript
JS代码优化技巧之通俗版(减少js体积)
2011/12/23 Javascript
ie支持function.bind()方法实现代码
2012/12/27 Javascript
jquery+ajax实现省市区三级联动效果简单示例
2017/01/04 Javascript
js模块加载方式浅析
2017/08/12 Javascript
微信小程序实现选项卡功能
2020/06/19 Javascript
微信小程序使用radio显示单选项功能【附源码下载】
2017/12/11 Javascript
IE8中jQuery.load()加载页面不显示的原因
2018/11/15 jQuery
用jQuery实现抽奖程序
2020/04/12 jQuery
初步讲解Python中的元组概念
2015/05/21 Python
编写Python CGI脚本的教程
2015/06/29 Python
浅谈python jieba分词模块的基本用法
2017/11/09 Python
pandas.DataFrame.to_json按行转json的方法
2018/06/05 Python
Python使用requests提交HTTP表单的方法
2018/12/26 Python
Python中文编码知识点
2019/02/18 Python
python爬虫 execjs安装配置及使用
2019/07/30 Python
用Python解数独的方法示例
2019/10/24 Python
numpy数组做图片拼接的实现(concatenate、vstack、hstack)
2019/11/08 Python
python 实现批量替换文本中的某部分内容
2019/12/13 Python
Myprotein法国官网:欧洲第一运动营养品牌
2019/03/26 全球购物
暑假安全保证书
2015/02/28 职场文书
小学生班干部竞选稿
2015/11/20 职场文书
党章党规党纪学习心得体会
2016/01/14 职场文书
vue打包时去掉所有的console.log
2022/04/10 Vue.js