dpn网络的pytorch实现方式


Posted in Python on January 14, 2020

我就废话不多说了,直接上代码吧!

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F



class CatBnAct(nn.Module):
    """Concatenate (if given a tuple), then BatchNorm2d + activation.

    Accepts either a single tensor or a tuple of tensors; a tuple is
    concatenated along the channel dimension before normalization.
    """

    def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)):
        super(CatBnAct, self).__init__()
        # eps=0.001 matches the original DPN (MXNet) implementation
        self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
        self.act = activation_fn

    def forward(self, x):
        if isinstance(x, tuple):
            x = torch.cat(x, dim=1)
        return self.act(self.bn(x))


class BnActConv2d(nn.Module):
    """Pre-activation conv unit: BatchNorm2d -> activation -> Conv2d.

    Fix: the original signature named its first parameter ``s`` while the
    body referenced the undefined name ``in_chs``, so construction raised
    NameError. All call sites in this file pass ``in_chs=...`` by keyword,
    so renaming the parameter to ``in_chs`` is the correct, compatible fix.
    """

    def __init__(self, in_chs, out_chs, kernel_size, stride,
                 padding=0, groups=1, activation_fn=nn.ReLU(inplace=True)):
        super(BnActConv2d, self).__init__()
        # eps=0.001 matches the original DPN (MXNet) implementation
        self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
        self.act = activation_fn
        # Bias is redundant directly after BatchNorm, hence bias=False.
        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding,
                              groups=groups, bias=False)

    def forward(self, x):
        return self.conv(self.act(self.bn(x)))


class InputBlock(nn.Module):
    """DPN stem: stride-2 conv from RGB, BN, activation, 3x3/2 max-pool.

    Overall spatial reduction is 4x (conv stride 2 followed by pool stride 2).
    """

    def __init__(self, num_init_features, kernel_size=7,
                 padding=3, activation_fn=nn.ReLU(inplace=True)):
        super(InputBlock, self).__init__()
        self.conv = nn.Conv2d(3, num_init_features, kernel_size=kernel_size,
                              stride=2, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
        self.act = activation_fn
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(self.act(self.bn(self.conv(x))))


class DualPathBlock(nn.Module):
    """One dual-path block: a ResNet-style residual path plus a
    DenseNet-style concatenation path, sharing one bottleneck.

    ``block_type``:
        'proj'   — stride-1 projection shortcut (first block of stage conv2)
        'down'   — stride-2 projection shortcut (first block of later stages)
        'normal' — identity shortcut; input must be the (resid, dense) tuple
                   produced by the previous block.

    Fix: the original compared strings with ``is`` (identity, not equality),
    which only works by accident of CPython string interning; it also
    validated ``block_type`` with ``assert``, which is stripped under ``-O``.
    Both are replaced with ``==`` and an explicit ``ValueError``.
    """

    def __init__(self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups,
                 block_type='normal', b=False):
        super(DualPathBlock, self).__init__()
        self.num_1x1_c = num_1x1_c
        self.inc = inc
        self.b = b
        if block_type == 'proj':
            self.key_stride = 1
            self.has_proj = True
        elif block_type == 'down':
            self.key_stride = 2
            self.has_proj = True
        elif block_type == 'normal':
            self.key_stride = 1
            self.has_proj = False
        else:
            raise ValueError(
                "block_type must be 'proj', 'down' or 'normal', got %r" % (block_type,))

        if self.has_proj:
            # Using different member names here to allow easier parameter key
            # matching for conversion (from the original MXNet weights).
            if self.key_stride == 2:
                self.c1x1_w_s2 = BnActConv2d(
                    in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2)
            else:
                self.c1x1_w_s1 = BnActConv2d(
                    in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
        self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
        self.c3x3_b = BnActConv2d(
            in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3,
            stride=self.key_stride, padding=1, groups=groups)
        if b:
            # 'b' variant: BN-act once, then two separate 1x1 convs for the
            # residual (num_1x1_c ch) and dense (inc ch) outputs.
            self.c1x1_c = CatBnAct(in_chs=num_3x3_b)
            self.c1x1_c1 = nn.Conv2d(num_3x3_b, num_1x1_c, kernel_size=1, bias=False)
            self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False)
        else:
            # Single conv producing both outputs; split by channel slicing.
            self.c1x1_c = BnActConv2d(
                in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``(resid, dense)``; input is a tensor or that same tuple."""
        x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x
        if self.has_proj:
            if self.key_stride == 2:
                x_s = self.c1x1_w_s2(x_in)
            else:
                x_s = self.c1x1_w_s1(x_in)
            # First num_1x1_c channels feed the residual sum; the remaining
            # 2*inc channels seed the dense concatenation path.
            x_s1 = x_s[:, :self.num_1x1_c, :, :]
            x_s2 = x_s[:, self.num_1x1_c:, :, :]
        else:
            # Identity shortcut: reuse the previous block's two paths.
            x_s1 = x[0]
            x_s2 = x[1]
        x_in = self.c1x1_a(x_in)
        x_in = self.c3x3_b(x_in)
        if self.b:
            x_in = self.c1x1_c(x_in)
            out1 = self.c1x1_c1(x_in)
            out2 = self.c1x1_c2(x_in)
        else:
            x_in = self.c1x1_c(x_in)
            out1 = x_in[:, :self.num_1x1_c, :, :]
            out2 = x_in[:, self.num_1x1_c:, :, :]
        resid = x_s1 + out1
        dense = torch.cat([x_s2, out2], dim=1)
        return resid, dense


class DPN(nn.Module):
    """Dual Path Network backbone + classifier.

    Args:
        small: use the DPN-68-style 3x3 stem instead of the 7x7 stem, and
            a bottleneck-width factor of 1 instead of 4.
        num_init_features: channels produced by the stem conv.
        k_r: base width of the grouped-conv bottleneck.
        groups: cardinality of the grouped 3x3 convolutions.
        b: build the 'b' variant of DualPathBlock.
        k_sec: number of blocks in each of the four stages (conv2..conv5).
        inc_sec: dense-path channel increment for each stage.
        num_classes: classifier output size.
        test_time_pool: apply the extra eval-time pooling scheme in logits().

    The four stage-construction sections of the original were verbatim
    copies differing only in base width / depth / increment, so they are
    built in a single loop here; module names ('conv2_1', ...) are
    preserved exactly so pretrained state dicts still match.
    """

    def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
                 b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
                 num_classes=1000, test_time_pool=False):
        super(DPN, self).__init__()
        self.test_time_pool = test_time_pool
        self.b = b
        bw_factor = 1 if small else 4

        blocks = OrderedDict()

        # Stem ('conv1'): small variant uses a 3x3 conv, standard a 7x7 conv.
        if small:
            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
        else:
            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)

        # Stages 'conv2'..'conv5'. The first block of conv2 is a stride-1
        # projection; each later stage opens with a stride-2 downsampling
        # projection, followed by 'normal' (identity-shortcut) blocks.
        in_chs = num_init_features
        for stage_idx, (base_bw, depth, inc) in enumerate(
                zip((64, 128, 256, 512), k_sec, inc_sec)):
            bw = base_bw * bw_factor
            r = (k_r * bw) // (64 * bw_factor)
            first_type = 'proj' if stage_idx == 0 else 'down'
            prefix = 'conv%d_' % (stage_idx + 2)
            blocks[prefix + '1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, first_type, b)
            # Projection block emits bw residual + 3*inc dense channels.
            in_chs = bw + 3 * inc
            for i in range(2, depth + 1):
                blocks[prefix + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
                in_chs += inc  # each block grows the dense path by inc
        blocks['conv5_bn_ac'] = CatBnAct(in_chs)

        self.features = nn.Sequential(blocks)

        # Using 1x1 conv for the FC layer to allow the extra pooling scheme
        # (the classifier can run on a spatial map before global pooling).
        self.last_linear = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True)

    def logits(self, features):
        """Pool `features` and apply the classifier; returns (N, num_classes)."""
        if not self.training and self.test_time_pool:
            x = F.avg_pool2d(features, kernel_size=7, stride=1)
            out = self.last_linear(x)
            # The extra test time pool should be pooling an img_size//32 - 6 size patch
            out = adaptive_avgmax_pool2d(out, pool_type='avgmax')
        else:
            x = adaptive_avgmax_pool2d(features, pool_type='avg')
            out = self.last_linear(x)
        return out.view(out.size(0), -1)

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x

""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
 * 'avg' - Average pooling
 * 'max' - Max pooling
 * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
 * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim

Both a functional and a nn.Module version of the pooling is provided.

"""

def pooling_factor(pool_type='avg'):
    """Return the channel multiplier a pooling mode applies.

    Only 'avgmaxc' doubles the feature dimension; every other mode is 1.
    """
    if pool_type == 'avgmaxc':
        return 2
    return 1


def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
    """Selectable global pooling function with dynamic input kernel size.

    Pools over the entire spatial extent of ``x``. ``pool_type`` is one of
    'avg', 'max', 'avgmax' (mean of both), 'avgmaxc' (channel concat of
    both); anything else falls back to 'avg' with a warning.
    """
    # The kernel always covers the full H x W of the input.
    kernel = (x.size(2), x.size(3))

    def _avg(t):
        return F.avg_pool2d(t, kernel_size=kernel, padding=padding,
                            count_include_pad=count_include_pad)

    def _max(t):
        return F.max_pool2d(t, kernel_size=kernel, padding=padding)

    if pool_type == 'avgmaxc':
        return torch.cat([_avg(x), _max(x)], dim=1)
    if pool_type == 'avgmax':
        return 0.5 * (_avg(x) + _max(x))
    if pool_type == 'max':
        return _max(x)
    if pool_type != 'avg':
        print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
    return _avg(x)


class AdaptiveAvgMaxPool2d(torch.nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    ``pool_type`` is one of 'avg', 'max', 'avgmax' (mean of both),
    'avgmaxc' (channel concat of both); anything else falls back to 'avg'
    with a warning.

    Fix: the original 'avgmax' branch ended with ``.squeeze(dim=0)`` after
    ``torch.sum`` had already removed the stacking dimension, so for a
    batch of size 1 it incorrectly dropped the batch dimension. The sum
    over the stacked pair is replaced with a plain elementwise mean, which
    preserves the (N, C, H', W') shape for every batch size.
    """

    def __init__(self, output_size=1, pool_type='avg'):
        super(AdaptiveAvgMaxPool2d, self).__init__()
        self.output_size = output_size
        self.pool_type = pool_type
        if pool_type in ('avgmaxc', 'avgmax'):
            self.pool = nn.ModuleList([
                nn.AdaptiveAvgPool2d(output_size),
                nn.AdaptiveMaxPool2d(output_size),
            ])
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d(output_size)
        else:
            if pool_type != 'avg':
                print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
            self.pool = nn.AdaptiveAvgPool2d(output_size)

    def forward(self, x):
        if self.pool_type == 'avgmaxc':
            return torch.cat([p(x) for p in self.pool], dim=1)
        if self.pool_type == 'avgmax':
            avg_out, max_out = (p(x) for p in self.pool)
            return 0.5 * (avg_out + max_out)
        return self.pool(x)

    def factor(self):
        # Channel multiplier this layer applies ('avgmaxc' doubles channels).
        return pooling_factor(self.pool_type)

    def __repr__(self):
        return '%s (output_size=%s, pool_type=%s)' % (
            self.__class__.__name__, self.output_size, self.pool_type)

以上这篇dpn网络的pytorch实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单的通用表达式求10乘阶示例
Mar 03 Python
Python压缩解压缩zip文件及破解zip文件密码的方法
Nov 04 Python
Python开发中爬虫使用代理proxy抓取网页的方法示例
Sep 26 Python
Python程序员面试题 你必须提前准备!(答案及解析)
Jan 23 Python
python 3.3 下载固定链接文件并保存的方法
Dec 18 Python
Python中如何使用if语句处理列表实例代码
Feb 24 Python
Python实现变声器功能(萝莉音御姐音)
Dec 05 Python
Python: 传递列表副本方式
Dec 19 Python
Python爬虫实例——scrapy框架爬取拉勾网招聘信息
Jul 14 Python
利用python爬取有道词典的方法
Dec 08 Python
Python 实现RSA加解密文本文件
Dec 30 Python
解决selenium+Headless Chrome实现不弹出浏览器自动化登录的问题
Jan 09 Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
Python vtk读取并显示dicom文件示例
Jan 13 #Python
Python解析多帧dicom数据详解
Jan 13 #Python
python 将dicom图片转换成jpg图片的实例
Jan 13 #Python
基于Python和PyYAML读取yaml配置文件数据
Jan 13 #Python
Python 实现判断图片格式并转换,将转换的图像存到生成的文件夹中
Jan 13 #Python
You might like
基于PHP CURL获取邮箱地址的详解
2013/06/03 PHP
PHP5中实现多态的两种方法实例分享
2014/04/21 PHP
php解析xml方法实例详解
2015/05/12 PHP
PHP 传输会话curl函数的实例详解
2017/09/12 PHP
PHP实现与java 通信的插件使用教程
2019/08/11 PHP
jquery 重写 ajax提交并判断权限后 使用load方法报错解决方法
2016/01/19 Javascript
微信小程序 icon组件详细及实例代码
2016/10/25 Javascript
Bootstrap表格制作代码
2017/03/17 Javascript
深入学习 JavaScript中的函数调用
2017/03/23 Javascript
javascript数组定义的几种方法
2017/10/06 Javascript
详解vue 计算属性与方法跟侦听器区别(面试考点)
2018/04/23 Javascript
详解mpvue开发小程序小总结
2018/07/25 Javascript
详解React中传入组件的props改变时更新组件的几种实现方法
2018/09/13 Javascript
Vux+Axios拦截器增加loading的问题及实现方法
2018/11/08 Javascript
JS实现判断有效的数独算法示例
2019/02/25 Javascript
一百行JS代码实现一个校验工具
2019/04/30 Javascript
通过实例了解js函数中参数的传递
2019/06/15 Javascript
jQuery实现input[type=file]多图预览上传删除等功能
2019/08/02 jQuery
JS继承实现方法及优缺点详解
2020/09/02 Javascript
vue中如何添加百度统计代码
2020/12/19 Vue.js
[15:07]lgd_OG_m2_BP
2019/09/10 DOTA
python模拟新浪微博登陆功能(新浪微博爬虫)
2013/12/24 Python
Python求出0~100以内的所有素数
2018/01/23 Python
python使用tensorflow保存、加载和使用模型的方法
2018/01/31 Python
Python 实现在文件中的每一行添加一个逗号
2018/04/29 Python
关于python3中setup.py小概念解析
2019/08/22 Python
几款Python编译器比较与推荐(小结)
2020/10/15 Python
CSS3 特效范例整理
2011/08/22 HTML / CSS
伦敦平价潮流珠宝首饰品牌:Astrid & Miyu
2016/10/10 全球购物
应付会计岗位职责
2013/12/12 职场文书
展会邀请函范文
2014/01/26 职场文书
中学生学雷锋活动心得体会
2014/03/10 职场文书
高中美术教师事迹材料
2014/08/22 职场文书
纪念九一八事变演讲稿:忘记意味着背叛
2014/09/14 职场文书
2015年销售人员工作总结
2015/04/07 职场文书
win10电脑老是死机怎么办?win10系统老是死机的解决方法
2022/08/05 数码科技