PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 字符串中的字符倒转
Sep 06 Python
python实现获取序列中最小的几个元素
Sep 25 Python
Python映射拆分操作符用法实例
May 19 Python
详解Python中for循环是如何工作的
Jun 30 Python
python实现redis三种cas事务操作
Dec 19 Python
Python使用MD5加密算法对字符串进行加密操作示例
Mar 30 Python
一篇文章读懂Python赋值与拷贝
Apr 19 Python
Selenium的使用详解
Oct 19 Python
python列表list保留顺序去重的实例
Dec 14 Python
Python 下载及安装详细步骤
Nov 04 Python
了解一下python内建模块collections
Sep 07 Python
python之随机数函数的实现示例
Dec 30 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
使用PHP socke 向指定页面提交数据
2008/07/23 PHP
PHP在不同页面间传递Json数据示例代码
2013/06/08 PHP
载入进度条 效果
2006/07/08 Javascript
不错的JS中变量相关的细节分析
2007/08/13 Javascript
JS小框架 fly javascript framework
2009/11/26 Javascript
基于jquery的防止大图片撑破页面的实现代码(立即缩放)
2011/10/24 Javascript
让JavaScript和其它资源并发下载的方法
2014/10/16 Javascript
JS实现在页面随时自定义背景颜色的方法
2015/02/27 Javascript
EasyUI中datagrid在ie下reload失败解决方案
2015/03/09 Javascript
关于webpack2和模块打包的新手指南(小结)
2017/08/07 Javascript
原生js实现省市区三级联动代码分享
2018/02/12 Javascript
vue.js中toast用法及使用toast弹框的实例代码
2018/08/27 Javascript
vue实现二级导航栏效果
2019/10/19 Javascript
微信小程序中插入激励视频广告并获取收益(实例代码)
2019/12/06 Javascript
react PropTypes校验传递的值操作示例
2020/04/28 Javascript
浅谈React中组件逻辑复用的那些事儿
2020/05/21 Javascript
实例讲解React 组件生命周期
2020/07/08 Javascript
python正则表达式之作业计算器
2016/03/18 Python
windows系统下Python环境的搭建(Aptana Studio)
2017/03/06 Python
Python使用SQLite和Excel操作进行数据分析
2018/01/20 Python
Python cookbook(数据结构与算法)将多个映射合并为单个映射的方法
2018/04/19 Python
使用python编写udp协议的ping程序方法
2018/04/22 Python
基于Python实现船舶的MMSI的获取(推荐)
2019/10/21 Python
详解python 中in 的 用法
2019/12/12 Python
Python SQLAlchemy库的使用方法
2020/10/13 Python
Python如何实现感知器的逻辑电路
2020/12/25 Python
Python入门基础之数字字符串与列表
2021/02/01 Python
比利时香水网上商店:NOTINO
2018/03/28 全球购物
幼儿园毕业家长感言
2014/02/10 职场文书
公务员转正鉴定材料
2014/02/11 职场文书
优秀医生事迹材料
2014/02/12 职场文书
开学典礼感言
2014/02/16 职场文书
公司自我介绍演讲稿
2014/08/21 职场文书
2015年端午节活动策划书
2015/05/05 职场文书
详解JAVA的控制语句
2021/11/11 Java/Android
Redis基本数据类型Zset有序集合常用操作
2022/06/01 Redis