dpn网络的pytorch实现方式


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
import torch.nn.functional as F



class CatBnAct(nn.Module):
 def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)):
  super(CatBnAct, self).__init__()
  self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
  self.act = activation_fn

 def forward(self, x):
  x = torch.cat(x, dim=1) if isinstance(x, tuple) else x
  return self.act(self.bn(x))


class BnActConv2d(nn.Module):
 def __init__(self, s, out_chs, kernel_size, stride,
     padding=0, groups=1, activation_fn=nn.ReLU(inplace=True)):
  super(BnActConv2d, self).__init__()
  self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
  self.act = activation_fn
  self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, groups=groups, bias=False)

 def forward(self, x):
  return self.conv(self.act(self.bn(x)))


class InputBlock(nn.Module):
 def __init__(self, num_init_features, kernel_size=7,
     padding=3, activation_fn=nn.ReLU(inplace=True)):
  super(InputBlock, self).__init__()
  self.conv = nn.Conv2d(
   3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
  self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
  self.act = activation_fn
  self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

 def forward(self, x):
  x = self.conv(x)
  x = self.bn(x)
  x = self.act(x)
  x = self.pool(x)
  return x


class DualPathBlock(nn.Module):
 def __init__(
   self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
  super(DualPathBlock, self).__init__()
  self.num_1x1_c = num_1x1_c
  self.inc = inc
  self.b = b
  if block_type is 'proj':
   self.key_stride = 1
   self.has_proj = True
  elif block_type is 'down':
   self.key_stride = 2
   self.has_proj = True
  else:
   assert block_type is 'normal'
   self.key_stride = 1
   self.has_proj = False

  if self.has_proj:
   # Using different member names here to allow easier parameter key matching for conversion
   if self.key_stride == 2:
    self.c1x1_w_s2 = BnActConv2d(
     in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2)
   else:
    self.c1x1_w_s1 = BnActConv2d(
     in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
  self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
  self.c3x3_b = BnActConv2d(
   in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3,
   stride=self.key_stride, padding=1, groups=groups)
  if b:
   self.c1x1_c = CatBnAct(in_chs=num_3x3_b)
   self.c1x1_c1 = nn.Conv2d(num_3x3_b, num_1x1_c, kernel_size=1, bias=False)
   self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False)
  else:
   self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)

 def forward(self, x):
  x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x
  if self.has_proj:
   if self.key_stride == 2:
    x_s = self.c1x1_w_s2(x_in)
   else:
    x_s = self.c1x1_w_s1(x_in)
   x_s1 = x_s[:, :self.num_1x1_c, :, :]
   x_s2 = x_s[:, self.num_1x1_c:, :, :]
  else:
   x_s1 = x[0]
   x_s2 = x[1]
  x_in = self.c1x1_a(x_in)
  x_in = self.c3x3_b(x_in)
  if self.b:
   x_in = self.c1x1_c(x_in)
   out1 = self.c1x1_c1(x_in)
   out2 = self.c1x1_c2(x_in)
  else:
   x_in = self.c1x1_c(x_in)
   out1 = x_in[:, :self.num_1x1_c, :, :]
   out2 = x_in[:, self.num_1x1_c:, :, :]
  resid = x_s1 + out1
  dense = torch.cat([x_s2, out2], dim=1)
  return resid, dense


class DPN(nn.Module):
 def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
     b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
     num_classes=1000, test_time_pool=False):
  super(DPN, self).__init__()
  self.test_time_pool = test_time_pool
  self.b = b
  bw_factor = 1 if small else 4

  blocks = OrderedDict()

  # conv1
  if small:
   blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
  else:
   blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)

  # conv2
  bw = 64 * bw_factor
  inc = inc_sec[0]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[0] + 1):
   blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv3
  bw = 128 * bw_factor
  inc = inc_sec[1]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[1] + 1):
   blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv4
  bw = 256 * bw_factor
  inc = inc_sec[2]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[2] + 1):
   blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc

  # conv5
  bw = 512 * bw_factor
  inc = inc_sec[3]
  r = (k_r * bw) // (64 * bw_factor)
  blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
  in_chs = bw + 3 * inc
  for i in range(2, k_sec[3] + 1):
   blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
   in_chs += inc
  blocks['conv5_bn_ac'] = CatBnAct(in_chs)

  self.features = nn.Sequential(blocks)

  # Using 1x1 conv for the FC layer to allow the extra pooling scheme
  self.last_linear = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True)

 def logits(self, features):
  if not self.training and self.test_time_pool:
   x = F.avg_pool2d(features, kernel_size=7, stride=1)
   out = self.last_linear(x)
   # The extra test time pool should be pooling an img_size//32 - 6 size patch
   out = adaptive_avgmax_pool2d(out, pool_type='avgmax')
  else:
   x = adaptive_avgmax_pool2d(features, pool_type='avg')
   out = self.last_linear(x)
  return out.view(out.size(0), -1)

 def forward(self, input):
  x = self.features(input)
  x = self.logits(x)
  return x

""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
 * 'avg' - Average pooling
 * 'max' - Max pooling
 * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
 * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim

Both a functional and a nn.Module version of the pooling is provided.

"""

def pooling_factor(pool_type='avg'):
 return 2 if pool_type == 'avgmaxc' else 1


def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
 """Selectable global pooling function with dynamic input kernel size
 """
 if pool_type == 'avgmaxc':
  x = torch.cat([
   F.avg_pool2d(
    x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad),
   F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
  ], dim=1)
 elif pool_type == 'avgmax':
  x_avg = F.avg_pool2d(
    x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
  x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
  x = 0.5 * (x_avg + x_max)
 elif pool_type == 'max':
  x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
 else:
  if pool_type != 'avg':
   print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
  x = F.avg_pool2d(
   x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
 return x


class AdaptiveAvgMaxPool2d(torch.nn.Module):
 """Selectable global pooling layer with dynamic input kernel size
 """
 def __init__(self, output_size=1, pool_type='avg'):
  super(AdaptiveAvgMaxPool2d, self).__init__()
  self.output_size = output_size
  self.pool_type = pool_type
  if pool_type == 'avgmaxc' or pool_type == 'avgmax':
   self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)])
  elif pool_type == 'max':
   self.pool = nn.AdaptiveMaxPool2d(output_size)
  else:
   if pool_type != 'avg':
    print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
   self.pool = nn.AdaptiveAvgPool2d(output_size)

 def forward(self, x):
  if self.pool_type == 'avgmaxc':
   x = torch.cat([p(x) for p in self.pool], dim=1)
  elif self.pool_type == 'avgmax':
   x = 0.5 * torch.sum(torch.stack([p(x) for p in self.pool]), 0).squeeze(dim=0)
  else:
   x = self.pool(x)
  return x

 def factor(self):
  return pooling_factor(self.pool_type)

 def __repr__(self):
  return self.__class__.__name__ + ' (' \
    + 'output_size=' + str(self.output_size) \
    + ', pool_type=' + self.pool_type + ')'

以上这篇dpn网络的pytorch实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细介绍Python语言中的按位运算符
Nov 26 Python
python Django批量导入数据
Mar 25 Python
Python读取一个目录下所有目录和文件的方法
Jul 15 Python
python正则表达式面试题解答
Apr 28 Python
Python3实现的简单验证码识别功能示例
May 02 Python
Python之inspect模块实现获取加载模块路径的方法
Oct 16 Python
利用python GDAL库读写geotiff格式的遥感影像方法
Nov 29 Python
python3正则模块re的使用方法详解
Feb 11 Python
浅谈pytorch池化maxpool2D注意事项
Feb 18 Python
Python连接SQLite数据库并进行增册改查操作方法详解
Feb 18 Python
查看jupyter notebook每个单元格运行时间实例
Apr 22 Python
python 实现批量图片识别并翻译
Nov 02 Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
Python vtk读取并显示dicom文件示例
Jan 13 #Python
Python解析多帧dicom数据详解
Jan 13 #Python
python 将dicom图片转换成jpg图片的实例
Jan 13 #Python
基于Python和PyYAML读取yaml配置文件数据
Jan 13 #Python
Python 实现判断图片格式并转换,将转换的图像存到生成的文件夹中
Jan 13 #Python
You might like
PHP逐行输出(ob_flush与flush的组合)
2012/02/04 PHP
php发送html格式文本邮件的方法
2015/06/10 PHP
json 入门基础教程 推荐
2009/10/31 Javascript
Js 时间间隔计算的函数(间隔天数)
2011/11/15 Javascript
原始XMLHttpRequest方法详情回顾
2013/11/28 Javascript
jquery制作搜狐快站页面效果示例分享
2014/02/21 Javascript
ExtJS 刷新后如何默认选中刷新前最后一次选中的节点
2014/04/03 Javascript
JavaScript闭包实例讲解
2014/04/22 Javascript
JavaScript字符串对象charAt方法入门实例(用于取得指定位置的字符)
2014/10/17 Javascript
jQuery的Ajax用户认证和注册技术实例教程(附demo源码)
2015/12/08 Javascript
JS实现alert中显示换行的方法
2015/12/17 Javascript
JS判断字符串变量是否含有某个字串的实现方法
2016/06/03 Javascript
jquery实现ajax提交表单信息的简单方法(推荐)
2016/08/24 Javascript
基于JavaScript实现类名的添加与移除
2017/04/23 Javascript
JS设计模式之命令模式概念与用法分析
2018/02/06 Javascript
vue-cli3.0 环境变量与模式配置方法
2018/11/08 Javascript
详解node.js创建一个web服务器(Server)的详细步骤
2021/01/15 Javascript
[02:11]DOTA2上海特级锦标赛主赛事第二日RECAP
2016/03/04 DOTA
Python实现抓取百度搜索结果页的网站标题信息
2015/01/22 Python
python 出现SyntaxError: non-keyword arg after keyword arg错误解决办法
2017/02/14 Python
Python3 适合初学者学习的银行账户登录系统实例
2017/08/08 Python
matplotlib作图添加表格实例代码
2018/01/23 Python
python使用PIL剪切和拼接图片
2020/03/23 Python
python 的topk算法实例
2020/04/02 Python
python实现与redis交互操作详解
2020/04/21 Python
简单了解python列表和元组的区别
2020/05/14 Python
Python3 用什么IDE开发工具比较好
2020/11/28 Python
描述RIP和OSPF区别以及特点
2015/01/17 面试题
个人优缺点自我评价
2014/01/27 职场文书
教师节商场活动方案
2014/02/13 职场文书
小学节能减排倡议书
2014/05/15 职场文书
校园绿化美化方案
2014/06/08 职场文书
2014年办公室文秘工作总结
2014/12/09 职场文书
2016初一新生军训心得体会
2016/01/11 职场文书
SpringBoot项目多数据源及mybatis 驼峰失效的问题解决方法
2022/07/07 Java/Android
mysql sock文件存储了什么信息
2022/07/15 MySQL