PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现从脚本里运行scrapy的方法
Apr 07 Python
python妹子图简单爬虫实例
Jul 07 Python
浅谈Python 的枚举 Enum
Jun 12 Python
pandas 获取季度,月度,年度首尾日期的方法
Apr 11 Python
Python之批量创建文件的实例讲解
May 10 Python
python pandas 对series和dataframe的重置索引reindex方法
Jun 07 Python
Python读取txt某几列绘图的方法
Oct 14 Python
使用Python实现毫秒级抢单功能
Jun 06 Python
python求绝对值的三种方法小结
Dec 04 Python
Python chardet库识别编码原理解析
Feb 18 Python
Python3 xml.etree.ElementTree支持的XPath语法详解
Mar 06 Python
Python实现汇率转换操作
May 03 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
Zend Guard一些常见问题解答
2008/09/11 PHP
如何利用php array_multisort函数 对数据库结果进行复杂排序
2013/06/08 PHP
PHP转盘抽奖接口实例
2015/02/09 PHP
Ajax请求PHP后台接口返回信息的实例代码
2018/08/21 PHP
javascript下IE与FF兼容函数收集
2008/09/17 Javascript
js控制文本框只输入数字和小数点的方法
2015/03/10 Javascript
JavaScript的模块化开发框架Sea.js上手指南
2016/05/12 Javascript
React Native之TextInput组件解析示例
2017/08/22 Javascript
Vue.js中关于侦听器(watch)的高级用法示例
2018/05/02 Javascript
Node.js模块全局安装路径配置方法
2018/05/17 Javascript
微信小程序自定义select下拉选项框组件的实现代码
2018/08/28 Javascript
vue中使用vue-cli接入融云实现即时通信
2019/04/19 Javascript
vue路由教程之静态路由
2019/09/03 Javascript
vue登录注册实例详解
2019/09/14 Javascript
Vue data的数据响应式到底是如何实现的
2020/02/11 Javascript
js事件机制----捕获与冒泡机制实例分析
2020/05/22 Javascript
解决vue加scoped后就无法修改vant的UI组件的样式问题
2020/09/07 Javascript
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
2016/01/20 Python
Python Json序列化与反序列化的示例
2018/01/31 Python
使用python语言,比较两个字符串是否相同的实例
2018/06/29 Python
python单例模式的多种实现方法
2019/07/26 Python
centos7之Python3.74安装教程
2019/08/15 Python
使用Python的Turtle绘制哆啦A梦实例
2019/11/21 Python
Python中BeautifulSoup通过查找Id获取元素信息
2020/12/07 Python
移动端开发HTML5页面点击按钮后出现闪烁或黑色背景的解决办法
2018/09/19 HTML / CSS
巴西男士胡须和头发护理产品商店:Beard
2017/11/13 全球购物
阿玛瑞酒店中文官方网站:Amari.com
2018/02/13 全球购物
最受欢迎的自我评价
2013/12/22 职场文书
酒店副总岗位职责
2013/12/24 职场文书
工厂实习感言
2014/01/14 职场文书
《一个中国孩子的呼声》教学反思
2014/02/12 职场文书
人力资源经理的岗位职责
2014/03/02 职场文书
预备党员转正思想汇报
2014/09/26 职场文书
2014年小学工作总结
2014/11/26 职场文书
2015国庆节放假通知范文
2015/07/30 职场文书
导游词之太行山青龙峡
2020/01/14 职场文书