dpn网络的pytorch实现方式


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
import torch.nn.functional as F



class CatBnAct(nn.Module):
 def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)):
  super(CatBnAct, self).__init__()
  self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
  self.act = activation_fn

 def forward(self, x):
  x = torch.cat(x, dim=1) if isinstance(x, tuple) else x
  return self.act(self.bn(x))


class BnActConv2d(nn.Module):
 def __init__(self, s, out_chs, kernel_size, stride,
     padding=0, groups=1, activation_fn=nn.ReLU(inplace=True)):
  super(BnActConv2d, self).__init__()
  self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
  self.act = activation_fn
  self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, groups=groups, bias=False)

 def forward(self, x):
  return self.conv(self.act(self.bn(x)))


class InputBlock(nn.Module):
 def __init__(self, num_init_features, kernel_size=7,
     padding=3, activation_fn=nn.ReLU(inplace=True)):
  super(InputBlock, self).__init__()
  self.conv = nn.Conv2d(
   3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
  self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
  self.act = activation_fn
  self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

 def forward(self, x):
  x = self.conv(x)
  x = self.bn(x)
  x = self.act(x)
  x = self.pool(x)
  return x


class DualPathBlock(nn.Module):
 def __init__(
   self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
  super(DualPathBlock, self).__init__()
  self.num_1x1_c = num_1x1_c
  self.inc = inc
  self.b = b
  if block_type is 'proj':
   self.key_stride = 1
   self.has_proj = True
  elif block_type is 'down':
   self.key_stride = 2
   self.has_proj = True
  else:
   assert block_type is 'normal'
   self.key_stride = 1
   self.has_proj = False

  if self.has_proj:
   # Using different member names here to allow easier parameter key matching for conversion
   if self.key_stride == 2:
    self.c1x1_w_s2 = BnActConv2d(
     in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2)
   else:
    self.c1x1_w_s1 = BnActConv2d(
     in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
  self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
  self.c3x3_b = BnActConv2d(
   in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3,
   stride=self.key_stride, padding=1, groups=groups)
  if b:
   self.c1x1_c = CatBnAct(in_chs=num_3x3_b)
   self.c1x1_c1 = nn.Conv2d(num_3x3_b, num_1x1_c, kernel_size=1, bias=False)
   self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False)
  else:
   self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)

 def forward(self, x):
  x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x
  if self.has_proj:
   if self.key_stride == 2:
    x_s = self.c1x1_w_s2(x_in)
   else:
    x_s = self.c1x1_w_s1(x_in)
   x_s1 = x_s[:, :self.num_1x1_c, :, :]
   x_s2 = x_s[:, self.num_1x1_c:, :, :]
  else:
   x_s1 = x[0]
   x_s2 = x[1]
  x_in = self.c1x1_a(x_in)
  x_in = self.c3x3_b(x_in)
  if self.b:
   x_in = self.c1x1_c(x_in)
   out1 = self.c1x1_c1(x_in)
   out2 = self.c1x1_c2(x_in)
  else:
   x_in = self.c1x1_c(x_in)
   out1 = x_in[:, :self.num_1x1_c, :, :]
   out2 = x_in[:, self.num_1x1_c:, :, :]
  resid = x_s1 + out1
  dense = torch.cat([x_s2, out2], dim=1)
  return resid, dense


class DPN(nn.Module):
 def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
     b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
     num_classes=1000, test_time_pool=False):
  super(DPN, self).__init__()
  self.test_time_pool = test_time_pool
  self.b = b
  bw_factor = 1 if small else 4

  blocks = OrderedDict()

  # conv1
  if small:
   blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
  else:
   blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)

  # conv2
  bw = 64 * bw_factor
  inc = inc_sec[0]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[0] + 1):
   blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv3
  bw = 128 * bw_factor
  inc = inc_sec[1]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[1] + 1):
   blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv4
  bw = 256 * bw_factor
  inc = inc_sec[2]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[2] + 1):
   blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv5
  bw = 512 * bw_factor
  inc = inc_sec[3]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[3] + 1):
   blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc
  blocks['conv5_bn_ac'] = CatBnAct(in_chs)

  self.features = nn.Sequential(blocks)

  # Using 1x1 conv for the FC layer to allow the extra pooling scheme
  self.last_linear = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True)

 def logits(self, features):
  if not self.training and self.test_time_pool:
   x = F.avg_pool2d(features, kernel_size=7, stride=1)
   out = self.last_linear(x)
   # The extra test time pool should be pooling an img_size//32 - 6 size patch
   out = adaptive_avgmax_pool2d(out, pool_type='avgmax')
  else:
   x = adaptive_avgmax_pool2d(features, pool_type='avg')
   out = self.last_linear(x)
  return out.view(out.size(0), -1)

 def forward(self, input):
  x = self.features(input)
  x = self.logits(x)
  return x

""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
 * 'avg' - Average pooling
 * 'max' - Max pooling
 * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
 * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim

Both a functional and a nn.Module version of the pooling is provided.

"""

def pooling_factor(pool_type='avg'):
 return 2 if pool_type == 'avgmaxc' else 1


def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
 """Selectable global pooling function with dynamic input kernel size
 """
 if pool_type == 'avgmaxc':
  x = torch.cat([
   F.avg_pool2d(
    x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad),
   F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
  ], dim=1)
 elif pool_type == 'avgmax':
  x_avg = F.avg_pool2d(
    x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
  x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
  x = 0.5 * (x_avg + x_max)
 elif pool_type == 'max':
  x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
 else:
  if pool_type != 'avg':
   print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
  x = F.avg_pool2d(
   x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
 return x


class AdaptiveAvgMaxPool2d(torch.nn.Module):
 """Selectable global pooling layer with dynamic input kernel size
 """
 def __init__(self, output_size=1, pool_type='avg'):
  super(AdaptiveAvgMaxPool2d, self).__init__()
  self.output_size = output_size
  self.pool_type = pool_type
  if pool_type == 'avgmaxc' or pool_type == 'avgmax':
   self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)])
  elif pool_type == 'max':
   self.pool = nn.AdaptiveMaxPool2d(output_size)
  else:
   if pool_type != 'avg':
    print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
   self.pool = nn.AdaptiveAvgPool2d(output_size)

 def forward(self, x):
  if self.pool_type == 'avgmaxc':
   x = torch.cat([p(x) for p in self.pool], dim=1)
  elif self.pool_type == 'avgmax':
   x = 0.5 * torch.sum(torch.stack([p(x) for p in self.pool]), 0).squeeze(dim=0)
  else:
   x = self.pool(x)
  return x

 def factor(self):
  return pooling_factor(self.pool_type)

 def __repr__(self):
  return self.__class__.__name__ + ' (' \
    + 'output_size=' + str(self.output_size) \
    + ', pool_type=' + self.pool_type + ')'

以上这篇dpn网络的pytorch实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
9种python web 程序的部署方式小结
Jun 30 Python
Python单元测试实例详解
May 25 Python
python实现对指定字符串补足固定长度倍数截断输出的方法
Nov 15 Python
python使用selenium登录QQ邮箱(附带滑动解锁)
Jan 23 Python
python实现在一个画布上画多个子图
Jan 19 Python
python3 logging日志封装实例
Apr 08 Python
如何基于python实现不邻接植花
May 01 Python
使用jupyter notebook运行python和R的步骤
Aug 13 Python
python实现b站直播自动发送弹幕功能
Feb 20 Python
字典算法实现及操作 --python(实用)
Mar 31 Python
Python编写可视化界面的全过程(Python+PyCharm+PyQt)
May 17 Python
python自动化八大定位元素讲解
Jul 09 Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
Python vtk读取并显示dicom文件示例
Jan 13 #Python
Python解析多帧dicom数据详解
Jan 13 #Python
python 将dicom图片转换成jpg图片的实例
Jan 13 #Python
基于Python和PyYAML读取yaml配置文件数据
Jan 13 #Python
Python 实现判断图片格式并转换,将转换的图像存到生成的文件夹中
Jan 13 #Python
You might like
博士208HAF收音机实习报告
2021/03/02 无线电
别人整理的服务器变量:$_SERVER
2006/10/20 PHP
PHP上传图片时判断上传文件是否为可用图片的方法
2016/10/20 PHP
利用PHP访问MySql数据库的逻辑操作以及增删改查的实例讲解
2017/08/30 PHP
CentOS7编译安装php7.1的教程详解
2019/04/18 PHP
JavaScript 实现模态对话框 源代码大全
2009/05/02 Javascript
JQuery中SetTimeOut传参问题探讨
2013/05/10 Javascript
js判断FCKeditor内容是否为空的两种形式
2013/05/14 Javascript
最好用的省市二级联动 原生js实现你值得拥有
2013/09/22 Javascript
jquery单行文字向上滚动效果示例
2014/03/06 Javascript
使用javascript提交form表单方法汇总
2015/06/25 Javascript
基于js实现微信发送好友如何分享到朋友圈、微博
2015/11/30 Javascript
Bootstrap框架下下拉框select搜索功能
2020/03/26 Javascript
BootStrap下拉菜单和滚动监听插件实现代码
2016/09/26 Javascript
利用jquery获取select下拉框的值
2016/11/23 Javascript
Ajax 加载数据 练习代码
2017/01/05 Javascript
Angular4的输入属性与输出属性实例详解
2017/11/29 Javascript
详解vue 在移动端体验上的优化解决方案
2019/05/20 Javascript
vue中使用element ui的弹窗与echarts之间的问题详解
2019/10/25 Javascript
[00:31]2016完美“圣”典风云人物:国士无双宣传片
2016/12/04 DOTA
python中反射用法实例
2015/03/27 Python
在Python中用keys()方法返回字典键的教程
2015/05/21 Python
基于Python代码编辑器的选用(详解)
2017/09/13 Python
python如何让类支持比较运算
2018/03/20 Python
python提取具有某种特定字符串的行数据方法
2018/12/11 Python
Python PyCharm如何进行断点调试
2019/07/05 Python
python实现对服务器脚本敏感信息的加密解密功能
2019/08/13 Python
Python Flask上下文管理机制实例解析
2020/03/16 Python
巧克力蛋糕店创业计划书
2014/01/14 职场文书
安全教育月活动总结
2014/05/05 职场文书
服务承诺书
2015/01/19 职场文书
人民调解协议书
2016/03/21 职场文书
小学记事作文之200字
2019/08/06 职场文书
Redis缓存-序列化对象存储乱码问题的解决
2021/06/21 Redis
IDEA使用SpringAssistant插件创建SpringCloud项目
2021/06/23 Java/Android
带你了解Java中的ForkJoin
2022/04/28 Java/Android