PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python线程锁(thread)学习示例
Dec 04 Python
python进程管理工具supervisor的安装与使用教程
Sep 05 Python
python在ubuntu中的几种安装方法(小结)
Dec 08 Python
Python实现翻转数组功能示例
Jan 12 Python
朴素贝叶斯分类算法原理与Python实现与使用方法案例
Jun 26 Python
在PyCharm中实现关闭一个死循环程序的方法
Nov 29 Python
Python中字符串与编码示例代码
May 20 Python
利用python生成照片墙的示例代码
Apr 09 Python
python thrift 实现 单端口多服务的过程
Jun 08 Python
matplotlib基础绘图命令之bar的使用方法
Aug 13 Python
python识别验证码的思路及解决方案
Sep 13 Python
python 制作网站小说下载器
Feb 20 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
解析php中var_dump,var_export,print_r三个函数的区别
2013/06/21 PHP
解决CodeIgniter伪静态失效
2014/06/09 PHP
PHP简单实现HTTP和HTTPS跨域共享session解决办法
2015/05/27 PHP
jQuery+Ajax+PHP“喜欢”评级功能实现代码
2015/10/08 PHP
Gambit vs ForZe BO3 第二场 2.13
2021/03/10 DOTA
JavaScript中的Document文档对象
2008/01/16 Javascript
js DataSet数据源处理代码
2010/03/29 Javascript
jquery.post用法示例代码
2014/01/03 Javascript
js数组依据下标删除元素
2015/04/14 Javascript
基于jquery实现动态竖向柱状条特效
2016/02/12 Javascript
Javascript 函数的四种调用模式
2016/11/05 Javascript
微信小程序开发(一) 微信登录流程详解
2017/01/11 Javascript
JS实现动态修改table及合并单元格的方法示例
2017/02/20 Javascript
详解node如何让一个端口同时支持https与http
2017/07/04 Javascript
Node.js如何使用Diffie-Hellman密钥交换算法详解
2017/09/05 Javascript
微信小程序使用swiper组件实现层叠轮播图
2018/11/04 Javascript
vue发送websocket请求和http post请求的实例代码
2019/07/11 Javascript
JavaScript实现动态生成表格
2020/08/02 Javascript
JavaScript中变量提升和函数提升的详解
2020/08/07 Javascript
[02:59]DOTA2完美大师赛主赛事第三日精彩集锦
2017/11/25 DOTA
python选择排序算法实例总结
2015/07/01 Python
使用pyecharts无法import Bar的解决方案
2020/04/23 Python
解决Python数据可视化中文部分显示方块问题
2020/05/16 Python
法国一家芭蕾舞鞋公司:Repetto
2018/11/12 全球购物
Etam艾格英国官网:法国著名女装品牌
2019/04/15 全球购物
市场营销专业求职信
2014/06/17 职场文书
卫生院艾滋病宣传活动小结
2014/07/09 职场文书
三八妇女节超市活动方案
2014/08/18 职场文书
党员反对四风问题思想汇报
2014/09/12 职场文书
防火标语大全
2014/10/06 职场文书
个人年终总结开头
2015/03/06 职场文书
2015年销售人员工作总结
2015/04/07 职场文书
办公室管理规章制度
2015/08/04 职场文书
2016年禁毒宣传活动总结
2016/04/05 职场文书
Python识别花卉种类鉴定网络热门植物并自动整理分类
2022/04/08 Python
centos环境下nginx高可用集群的搭建指南
2022/07/23 Servers