dpn网络的pytorch实现方式


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
import torch.nn.functional as F



class CatBnAct(nn.Module):
 def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)):
  super(CatBnAct, self).__init__()
  self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
  self.act = activation_fn

 def forward(self, x):
  x = torch.cat(x, dim=1) if isinstance(x, tuple) else x
  return self.act(self.bn(x))


class BnActConv2d(nn.Module):
 def __init__(self, s, out_chs, kernel_size, stride,
     padding=0, groups=1, activation_fn=nn.ReLU(inplace=True)):
  super(BnActConv2d, self).__init__()
  self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
  self.act = activation_fn
  self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, groups=groups, bias=False)

 def forward(self, x):
  return self.conv(self.act(self.bn(x)))


class InputBlock(nn.Module):
 def __init__(self, num_init_features, kernel_size=7,
     padding=3, activation_fn=nn.ReLU(inplace=True)):
  super(InputBlock, self).__init__()
  self.conv = nn.Conv2d(
   3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
  self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
  self.act = activation_fn
  self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

 def forward(self, x):
  x = self.conv(x)
  x = self.bn(x)
  x = self.act(x)
  x = self.pool(x)
  return x


class DualPathBlock(nn.Module):
 def __init__(
   self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
  super(DualPathBlock, self).__init__()
  self.num_1x1_c = num_1x1_c
  self.inc = inc
  self.b = b
  if block_type is 'proj':
   self.key_stride = 1
   self.has_proj = True
  elif block_type is 'down':
   self.key_stride = 2
   self.has_proj = True
  else:
   assert block_type is 'normal'
   self.key_stride = 1
   self.has_proj = False

  if self.has_proj:
   # Using different member names here to allow easier parameter key matching for conversion
   if self.key_stride == 2:
    self.c1x1_w_s2 = BnActConv2d(
     in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2)
   else:
    self.c1x1_w_s1 = BnActConv2d(
     in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
  self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
  self.c3x3_b = BnActConv2d(
   in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3,
   stride=self.key_stride, padding=1, groups=groups)
  if b:
   self.c1x1_c = CatBnAct(in_chs=num_3x3_b)
   self.c1x1_c1 = nn.Conv2d(num_3x3_b, num_1x1_c, kernel_size=1, bias=False)
   self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False)
  else:
   self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)

 def forward(self, x):
  x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x
  if self.has_proj:
   if self.key_stride == 2:
    x_s = self.c1x1_w_s2(x_in)
   else:
    x_s = self.c1x1_w_s1(x_in)
   x_s1 = x_s[:, :self.num_1x1_c, :, :]
   x_s2 = x_s[:, self.num_1x1_c:, :, :]
  else:
   x_s1 = x[0]
   x_s2 = x[1]
  x_in = self.c1x1_a(x_in)
  x_in = self.c3x3_b(x_in)
  if self.b:
   x_in = self.c1x1_c(x_in)
   out1 = self.c1x1_c1(x_in)
   out2 = self.c1x1_c2(x_in)
  else:
   x_in = self.c1x1_c(x_in)
   out1 = x_in[:, :self.num_1x1_c, :, :]
   out2 = x_in[:, self.num_1x1_c:, :, :]
  resid = x_s1 + out1
  dense = torch.cat([x_s2, out2], dim=1)
  return resid, dense


class DPN(nn.Module):
 def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
     b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
     num_classes=1000, test_time_pool=False):
  super(DPN, self).__init__()
  self.test_time_pool = test_time_pool
  self.b = b
  bw_factor = 1 if small else 4

  blocks = OrderedDict()

  # conv1
  if small:
   blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
  else:
   blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)

  # conv2
  bw = 64 * bw_factor
  inc = inc_sec[0]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[0] + 1):
   blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv3
  bw = 128 * bw_factor
  inc = inc_sec[1]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[1] + 1):
   blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv4
  bw = 256 * bw_factor
  inc = inc_sec[2]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[2] + 1):
   blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv5
  bw = 512 * bw_factor
  inc = inc_sec[3]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[3] + 1):
   blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc
  blocks['conv5_bn_ac'] = CatBnAct(in_chs)

  self.features = nn.Sequential(blocks)

  # Using 1x1 conv for the FC layer to allow the extra pooling scheme
  self.last_linear = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True)

 def logits(self, features):
  if not self.training and self.test_time_pool:
   x = F.avg_pool2d(features, kernel_size=7, stride=1)
   out = self.last_linear(x)
   # The extra test time pool should be pooling an img_size//32 - 6 size patch
   out = adaptive_avgmax_pool2d(out, pool_type='avgmax')
  else:
   x = adaptive_avgmax_pool2d(features, pool_type='avg')
   out = self.last_linear(x)
  return out.view(out.size(0), -1)

 def forward(self, input):
  x = self.features(input)
  x = self.logits(x)
  return x

""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
 * 'avg' - Average pooling
 * 'max' - Max pooling
 * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
 * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim

Both a functional and a nn.Module version of the pooling is provided.

"""

def pooling_factor(pool_type='avg'):
 return 2 if pool_type == 'avgmaxc' else 1


def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
 """Selectable global pooling function with dynamic input kernel size
 """
 if pool_type == 'avgmaxc':
  x = torch.cat([
   F.avg_pool2d(
    x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad),
   F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
  ], dim=1)
 elif pool_type == 'avgmax':
  x_avg = F.avg_pool2d(
    x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
  x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
  x = 0.5 * (x_avg + x_max)
 elif pool_type == 'max':
  x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
 else:
  if pool_type != 'avg':
   print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
  x = F.avg_pool2d(
   x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
 return x


class AdaptiveAvgMaxPool2d(torch.nn.Module):
 """Selectable global pooling layer with dynamic input kernel size
 """
 def __init__(self, output_size=1, pool_type='avg'):
  super(AdaptiveAvgMaxPool2d, self).__init__()
  self.output_size = output_size
  self.pool_type = pool_type
  if pool_type == 'avgmaxc' or pool_type == 'avgmax':
   self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)])
  elif pool_type == 'max':
   self.pool = nn.AdaptiveMaxPool2d(output_size)
  else:
   if pool_type != 'avg':
    print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
   self.pool = nn.AdaptiveAvgPool2d(output_size)

 def forward(self, x):
  if self.pool_type == 'avgmaxc':
   x = torch.cat([p(x) for p in self.pool], dim=1)
  elif self.pool_type == 'avgmax':
   x = 0.5 * torch.sum(torch.stack([p(x) for p in self.pool]), 0).squeeze(dim=0)
  else:
   x = self.pool(x)
  return x

 def factor(self):
  return pooling_factor(self.pool_type)

 def __repr__(self):
  return self.__class__.__name__ + ' (' \
    + 'output_size=' + str(self.output_size) \
    + ', pool_type=' + self.pool_type + ')'

以上这篇dpn网络的pytorch实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python把csv数据写入list和字典类型的变量脚本方法
Jun 15 Python
Python中 map()函数的用法详解
Jul 10 Python
python: 自动安装缺失库文件的方法
Oct 22 Python
Python列表常见操作详解(获取,增加,删除,修改,排序等)
Feb 18 Python
selenium获取当前页面的url、源码、title的方法
Jun 12 Python
python下PyGame的下载与安装过程及遇到问题
Aug 04 Python
wxPython之wx.DC绘制形状
Nov 19 Python
Python之变量类型和if判断方式
May 05 Python
Python计算矩阵的和积的实例详解
Sep 10 Python
python实现一个简单RPC框架的示例
Oct 28 Python
python语言实现贪吃蛇游戏
Nov 13 Python
pd.DataFrame中的几种索引变换的实现
Jun 16 Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
Python vtk读取并显示dicom文件示例
Jan 13 #Python
Python解析多帧dicom数据详解
Jan 13 #Python
python 将dicom图片转换成jpg图片的实例
Jan 13 #Python
基于Python和PyYAML读取yaml配置文件数据
Jan 13 #Python
Python 实现判断图片格式并转换,将转换的图像存到生成的文件夹中
Jan 13 #Python
You might like
PHP 日期加减的类,很不错
2009/10/10 PHP
用PHP即时捕捉PHP中的错误并发送email通知的实现代码
2013/01/19 PHP
php以post形式发送xml的方法
2014/11/04 PHP
PHP使用ODBC连接数据库的方法
2015/07/18 PHP
PHP读取mssql json数据中文乱码的解决办法
2016/04/11 PHP
如何离线执行php任务
2017/02/21 PHP
js计数器代码
2006/11/04 Javascript
Extjs列表详细信息窗口新建后自动加载解决方法
2010/04/02 Javascript
safari,opera嵌入iframe页面cookie读取问题解决方法
2010/06/23 Javascript
JavaScript Memoization 让函数也有记忆功能
2011/10/27 Javascript
自己写的兼容ie和ff的在线文本编辑器类似ewebeditor
2012/12/12 Javascript
Mac下使用charles遇到的问题以及解决办法
2017/01/10 Javascript
使用jQuery操作DOM的方法小结
2017/02/27 Javascript
bootstrap table 多选框分页保留示例代码
2017/03/08 Javascript
Vue声明式渲染详解
2017/05/17 Javascript
简单的网页广告特效实例
2017/08/19 Javascript
vue 微信扫码登录(自定义样式)
2020/01/06 Javascript
[00:59]DOTA2荣耀之路1:Doom is back!weapon X!
2018/05/22 DOTA
phpsir 开发 一个检测百度关键字网站排名的python 程序
2009/09/17 Python
在Ubuntu系统下安装使用Python的GUI工具wxPython
2016/02/18 Python
详解使用python的logging模块在stdout输出的两种方法
2017/05/17 Python
Python字符串处理实现单词反转
2017/06/14 Python
TensorFlow 滑动平均的示例代码
2018/06/19 Python
Python @property原理解析和用法实例
2020/02/11 Python
10个python爬虫入门基础代码实例 + 1个简单的python爬虫完整实例
2020/12/16 Python
python 基于opencv实现图像增强
2020/12/23 Python
美国最大的骑马用品零售商:HorseLoverZ
2017/01/12 全球购物
迪斯尼假期(欧洲、中东及非洲):Disney Holidays EMEA
2021/02/15 全球购物
物理分数没达标检讨书
2014/09/13 职场文书
小学生推普周国旗下讲话稿
2014/09/21 职场文书
教师批评与自我批评剖析材料
2014/10/16 职场文书
工厂见习报告范文
2014/10/31 职场文书
网络营销计划书
2015/01/17 职场文书
幼儿园园务工作总结2015
2015/05/18 职场文书
使用canvas实现雪花飘动效果的示例代码
2021/03/30 HTML / CSS
基于python制作简易版学生信息管理系统
2021/04/20 Python