PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 数据加密代码
Dec 24 Python
python异步任务队列示例
Apr 01 Python
Python脚本实现代码行数统计代码分享
Mar 10 Python
python实现比较两段文本不同之处的方法
May 30 Python
Python使用一行代码获取上个月是几月
Aug 30 Python
使用Python获取网段IP个数以及地址清单的方法
Nov 01 Python
利用python实现简易版的贪吃蛇游戏(面向python小白)
Dec 30 Python
python tkinter组件使用详解
Sep 16 Python
使用Python实现Wake On Lan远程开机功能
Jan 22 Python
Pycharm添加虚拟解释器报错问题解决方案
Oct 13 Python
Django vue前后端分离整合过程解析
Nov 20 Python
Python爬虫爬取微博热搜保存为 Markdown 文件的源码
Feb 22 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
使用PHPMYADMIN操作mysql数据库添加新用户和数据库的方法
2010/04/02 PHP
PHP下对数组进行排序的函数
2010/08/08 PHP
ThinkPHP的cookie和session冲突造成Cookie不能使用的解决方法
2014/07/01 PHP
php实现网页端验证码功能
2017/07/11 PHP
JavaScript中Object和Function的关系小结
2009/09/26 Javascript
JavaScript中的匀速运动和变速(缓冲)运动详细介绍
2012/11/11 Javascript
jquery 通过name快速取值示例
2014/01/24 Javascript
异步JavaScript编程中的Promise使用方法
2015/07/28 Javascript
JavaScript子窗口调用父窗口变量和函数的方法
2015/10/09 Javascript
解决jquery无法找到其他父级子集问题的方法
2016/05/10 Javascript
JS简单实现仿百度控制台输出信息效果
2016/09/04 Javascript
vue+vue-validator 表单验证功能的实现代码
2017/11/13 Javascript
js中async函数结合promise的小案例浅析
2019/04/14 Javascript
layUI使用layer.open,在content打开数据表格,获取值并返回的方法
2019/09/26 Javascript
vue element-ui实现input输入框金额数字添加千分位
2019/12/29 Javascript
javascript实现时钟动画
2020/12/03 Javascript
[00:33]DOTA2上海特级锦标赛 CDEC战队宣传片
2016/03/04 DOTA
[40:19]2018完美盛典CS.GO表演赛
2018/12/17 DOTA
wxpython中Textctrl回车事件无效的解决方法
2016/07/21 Python
python技能之数据导出excel的实例代码
2017/08/11 Python
Python基于Matplotlib库简单绘制折线图的方法示例
2017/08/14 Python
python安装后的目录在哪里
2020/06/21 Python
CSS3 滤镜 webkit-filter详细介绍及使用方法
2012/12/27 HTML / CSS
canvas学习笔记之绘制简单路径
2019/01/28 HTML / CSS
澳大利亚自然和有机的健康美容产品一站式商店:Ziani Beauty
2017/12/28 全球购物
CHARLES & KEITH澳大利亚官网:新加坡时尚品牌
2019/01/22 全球购物
大一自我鉴定范文
2013/12/27 职场文书
应届生求职自荐信
2014/07/04 职场文书
元旦趣味活动方案
2014/08/22 职场文书
党员四风自我剖析材料思想汇报
2014/09/13 职场文书
党员对照检查材料思想汇报
2014/09/16 职场文书
2015年宣传思想工作总结
2015/05/22 职场文书
2015秋季运动会通讯稿
2015/07/18 职场文书
2016年优秀共产党员先进事迹材料
2016/02/29 职场文书
干货:如何写好工作总结报告!
2019/05/10 职场文书
KVM基础命令详解
2022/04/30 Servers