PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则表达式 概述及常用字符
May 04 Python
Python random模块(获取随机数)常用方法和使用例子
May 13 Python
python连接远程ftp服务器并列出目录下文件的方法
Apr 01 Python
分数霸榜! python助你微信跳一跳拿高分
Jan 08 Python
完美解决安装完tensorflow后pip无法使用的问题
Jun 11 Python
Django生成PDF文档显示在网页上以及解决PDF中文显示乱码的问题
Jul 04 Python
python框架flask表单实现详解
Nov 04 Python
Python Numpy 自然数填充数组的实现
Nov 28 Python
python中的split()函数和os.path.split()函数使用详解
Dec 21 Python
python系统指定文件的查找只输出目录下所有文件及文件夹
Jan 19 Python
python编程进阶之异常处理用法实例分析
Feb 21 Python
Python页面加载的等待方式总结
Feb 28 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
一个PHP+MSSQL分页的例子
2006/10/09 PHP
PHP中使用unset销毁变量并内存释放问题
2012/07/05 PHP
PHP服务器页面间跳转实现方法
2012/08/02 PHP
PHP 异步执行方法,模拟多线程的应用分析
2013/06/03 PHP
php实现利用phpexcel导出数据
2013/08/24 PHP
php简单随机字符串生成方法示例
2017/04/19 PHP
JavaScript类和继承 prototype属性
2010/09/03 Javascript
JS中showModalDialog 的使用解析
2013/04/17 Javascript
调试代码导致IE出错的避免方法
2014/04/04 Javascript
jQuery中prev()方法用法实例
2015/01/08 Javascript
javascript实现超炫的向上滑行菜单实例
2015/08/03 Javascript
jquery+CSS3实现3D拖拽相册效果
2016/07/18 Javascript
jQuery基于排序功能实现上移、下移的方法
2016/11/26 Javascript
JavaScript自定义分页样式
2017/01/17 Javascript
分分钟玩转Vue.js组件(二)
2017/03/01 Javascript
解决iview多表头动态更改列元素发生的错误的方法
2018/11/02 Javascript
详解vue中async-await的使用误区
2018/12/05 Javascript
vue中过滤器filter的讲解
2019/01/21 Javascript
简述vue-cli中chainWebpack的使用方法
2019/07/30 Javascript
js实现小时钟效果
2020/03/25 Javascript
[54:06]OG vs TNC 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
Python 学习笔记
2008/12/27 Python
python批量下载图片的三种方法
2013/04/22 Python
python常见排序算法基础教程
2017/04/13 Python
python+pyqt5实现KFC点餐收银系统
2019/01/24 Python
Python Numpy中数据的常用保存与读取方法
2020/04/01 Python
Python tkinter实现日期选择器
2021/02/22 Python
如何用border-image实现文字气泡边框的示例代码
2020/01/21 HTML / CSS
HTML5通过调用canvas对象的getContext()方法来获取绘图环境
2014/06/23 HTML / CSS
Pottery Barn阿联酋:购买家具、家居装饰及更多
2019/12/08 全球购物
畜牧兽医本科生的自我评价
2014/03/03 职场文书
个人自我剖析材料
2014/09/30 职场文书
地下停车场租赁协议范本
2014/10/07 职场文书
2014年节能减排工作总结
2014/12/06 职场文书
浅谈JavaScript浅拷贝和深拷贝
2021/11/07 Javascript
python_tkinter弹出对话框创建
2022/03/20 Python