Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现分析apache和nginx日志文件并输出访客ip列表的方法
Apr 04 Python
Python实现的数据结构与算法之基本搜索详解
Apr 22 Python
Python爬虫之模拟知乎登录的方法教程
May 25 Python
Python_LDA实现方法详解
Oct 25 Python
使用Python检测文章抄袭及去重算法原理解析
Jun 14 Python
python绘制已知点的坐标的直线实例
Jul 04 Python
python重要函数eval多种用法解析
Jan 14 Python
Python使用循环神经网络解决文本分类问题的方法详解
Jan 16 Python
Python使用OpenPyXL处理Excel表格
Jul 02 Python
python Paramiko使用示例
Sep 21 Python
Python常用base64 md5 aes des crc32加密解密方法汇总
Nov 06 Python
通过实例解析python and和or使用方法
Nov 14 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
PHP.MVC的模板标签系统(二)
2006/09/05 PHP
Smarty Foreach 使用说明
2010/03/23 PHP
解析PHP自带的进位制之间的转换函数
2013/06/08 PHP
windows7下安装php的php-ssh2扩展教程
2014/07/04 PHP
laravel 解决路由除了根目录其他都404的问题
2019/10/18 PHP
使用PHP开发留言板功能
2019/11/19 PHP
javascript基于jQuery的表格悬停变色/恢复,表格点击变色/恢复,点击行选Checkbox
2008/08/05 Javascript
JQuery 引发两次$(document.ready)事件
2010/01/15 Javascript
jQuery html() in Firefox (uses .innerHTML) ignores DOM changes
2010/03/05 Javascript
一些常用且实用的原生JavaScript函数
2010/09/08 Javascript
setTimeout和setInterval的深入理解
2013/11/08 Javascript
alert出数组中的随即值代码
2014/09/25 Javascript
php,js,css字符串截取的办法集锦
2014/09/26 Javascript
Jquery遍历Json数据的方法
2015/04/20 Javascript
浅谈jQuery构造函数分析
2015/05/11 Javascript
JS实现兼容性好,自动置顶的淘宝悬浮工具栏效果
2015/09/18 Javascript
javascript读取文本节点方法小结
2016/12/15 Javascript
JS异步加载的三种实现方式
2017/03/16 Javascript
vue v-model动态生成详解
2018/06/30 Javascript
JavaScript(js)处理的HTML事件、键盘事件、鼠标事件简单示例
2019/11/19 Javascript
微信小程序自定义模态弹窗组件详解
2019/12/24 Javascript
[00:12]DAC SOLO赛卫冕冠军 VG.Paparazi灬展现SOLO技巧
2018/04/06 DOTA
python实现给字典添加条目的方法
2014/09/25 Python
python计算文本文件行数的方法
2015/07/06 Python
Python利用Nagios增加微信报警通知的功能
2016/02/18 Python
基于python 二维数组及画图的实例详解
2018/04/03 Python
使用pip安装python库的多种方式
2019/07/31 Python
TensorBoard 计算图的查看方式
2020/02/15 Python
联想加拿大官方网站:Lenovo Canada
2018/04/05 全球购物
介绍一下内联、左联、右联
2013/12/31 面试题
如何防止同一个帐户被多人同时登录
2013/08/01 面试题
父母对孩子说的话
2014/04/12 职场文书
家长学校教学计划
2015/01/19 职场文书
超强台风观后感
2015/06/09 职场文书
2016年公司“3.12”植树节活动总结
2016/03/16 职场文书
IDEA使用SpringAssistant插件创建SpringCloud项目
2021/06/23 Java/Android