PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现读取命令行参数的方法
May 22 Python
python编写微信远程控制电脑的程序
Jan 05 Python
pycharm执行python时,填写参数的方法
Oct 29 Python
讲解Python3中NumPy数组寻找特定元素下标的两种方法
Aug 04 Python
docker django无法访问redis容器的解决方法
Aug 21 Python
python与mysql数据库交互的实现
Jan 06 Python
Python threading.local代码实例及原理解析
Mar 16 Python
pytorch查看模型weight与grad方式
Jun 24 Python
Python正则re模块使用步骤及原理解析
Aug 18 Python
python 如何把docker-compose.yaml导入到数据库相关条目里
Jan 15 Python
使用python如何删除同一文件夹下相似的图片
May 07 Python
Python利用机器学习算法实现垃圾邮件的识别
Jun 28 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
PHP魔术方法的使用示例
2015/06/23 PHP
PHP中__set()实例用法和基础讲解
2019/07/23 PHP
表单提交验证类
2006/07/14 Javascript
Firebug入门指南(Firefox浏览器)
2010/08/21 Javascript
jquery $.getJSON()跨域请求
2011/12/21 Javascript
Prototype源码浅析 String部分(一)之有关indexOf优化
2012/01/15 Javascript
JQuery入门——用one()方法绑定事件处理函数(仅触发一次)
2013/02/05 Javascript
uploadify在Firefox下丢失session问题的解决方法
2013/08/07 Javascript
js鼠标及对象坐标控制属性详细解析
2013/12/14 Javascript
JavaScript给url网址进行encode编码的方法
2015/03/18 Javascript
js实现简单div拖拽功能实例
2015/05/12 Javascript
JS+CSS实现的经典tab选项卡效果代码
2015/09/16 Javascript
Express URL跳转(重定向)的实现方法
2017/04/07 Javascript
Vue.js做select下拉列表的实例(ul-li标签仿select标签)
2018/03/02 Javascript
深度了解vue.js中hooks的相关知识
2019/06/14 Javascript
解决vue动态路由异步加载import组件,加载不到module的问题
2020/07/26 Javascript
[01:24:34]2014 DOTA2华西杯精英邀请赛5 24 DK VS LGD
2014/05/25 DOTA
[49:08]OpTic vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
[54:43]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第一场 2月22日
2021/03/11 DOTA
python解析xml文件实例分析
2015/05/27 Python
Python 通配符删除文件的实例
2018/04/24 Python
Python常用模块之requests模块用法分析
2019/05/15 Python
python使用 cx_Oracle 模块进行查询操作示例
2019/11/28 Python
HTML5 本地存储实现购物车功能
2017/09/07 HTML / CSS
浅析HTML5的WebSocket与服务器推送事件
2016/02/19 HTML / CSS
使用html5新特性轻松监听任何App自带返回键的示例
2018/03/13 HTML / CSS
理肤泉俄罗斯官网:La Roche-Posay俄罗斯
2018/07/24 全球购物
汽车运用工程系毕业生自荐信
2013/12/27 职场文书
员工工作表扬信范文
2014/01/13 职场文书
小学教师暑期培训方案
2014/08/28 职场文书
村主任“四风”问题个人对照检查材料思想汇报
2014/10/02 职场文书
2014年绩效考核工作总结
2014/12/11 职场文书
运动员加油词
2015/07/18 职场文书
Vue Element-ui表单校验规则实现
2021/07/09 Vue.js
基于Python实现一个春节倒计时脚本
2022/01/22 Python
Java十分钟精通进阶适配器模式
2022/04/06 Java/Android