PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
完美解决Python2操作中文名文件乱码的问题
Jan 04 Python
基于python实现在excel中读取与生成随机数写入excel中
Jan 04 Python
Python 实现某个功能每隔一段时间被执行一次的功能方法
Oct 14 Python
Python实现简单层次聚类算法以及可视化
Mar 18 Python
Python 爬虫实现增加播客访问量的方法实现
Oct 31 Python
详解Django admin高级用法
Nov 06 Python
Python字符串格式化输出代码实例
Nov 22 Python
python函数声明和调用定义及原理详解
Dec 02 Python
Python3实现mysql连接和数据框的形成(实例代码)
Jan 17 Python
Django实现后台上传并显示图片功能
May 29 Python
Python通过队列来实现进程间通信的示例
Oct 14 Python
pytorch加载语音类自定义数据集的方法教程
Nov 10 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
PHP中动态显示签名和ip原理
2007/03/28 PHP
百度地图API应用之获取用户的具体位置
2014/06/10 PHP
CI使用Tank Auth转移数据库导致密码用户错误的解决办法
2014/06/12 PHP
php版本的cron定时任务执行器使用实例
2014/08/19 PHP
PHP实现十进制、二进制、八进制和十六进制转换相关函数用法分析
2017/04/25 PHP
PHP 无限级分类
2017/05/04 PHP
通过PHP的Wrapper无缝迁移原有项目到新服务的实现方法
2020/04/02 PHP
JavaScript 开发中规范性的一点感想
2009/06/23 Javascript
JavaScript中的this实例分析
2011/04/28 Javascript
apply和call方法定义及apply和call方法的区别
2015/11/15 Javascript
javascript中加var和不加var的区别 你真的懂吗
2016/01/06 Javascript
深入理解JavaScript内置函数
2016/06/03 Javascript
JS原型链 详解及示例代码
2016/09/06 Javascript
简单的jQuery拖拽排序效果的实现(增强动态)
2017/02/09 Javascript
微信小程序-滚动消息通知的实例代码
2017/08/03 Javascript
Vue from-validate 表单验证的示例代码
2017/09/26 Javascript
React数据传递之组件内部通信的方法
2017/12/31 Javascript
JavaScript中join()、splice()、slice()和split()函数用法示例
2018/08/24 Javascript
vue 在单页面应用里使用二级套嵌路由
2020/12/19 Vue.js
[43:26]完美世界DOTA2联赛PWL S2 Forest vs Rebirth 第二场 11.20
2020/11/23 DOTA
Python 元组(Tuple)操作详解
2014/03/11 Python
Python数据结构与算法之使用队列解决小猫钓鱼问题
2017/12/14 Python
Python中将dataframe转换为字典的实例
2018/04/13 Python
Jupyter中直接显示Matplotlib的图形方法
2018/05/24 Python
python matplotlib饼状图参数及用法解析
2019/11/04 Python
如何用OpenCV -python3实现视频物体追踪
2019/12/04 Python
django为Form生成的label标签添加class方式
2020/05/20 Python
详解python使用金山词霸的翻译功能(调试工具断点的使用)
2021/01/07 Python
《观舞记》教学反思
2014/04/16 职场文书
2014年社区重阳节活动策划方案
2014/09/16 职场文书
政府领导干部个人对照检查材料思想汇报
2014/09/24 职场文书
师德承诺书
2015/01/20 职场文书
教师年度个人总结
2015/02/11 职场文书
2015年企业团支部工作总结
2015/05/21 职场文书
详解Java线程池是如何重复利用空闲线程的
2021/06/26 Java/Android
Python学习之迭代器详解
2022/04/01 Python