使用pytorch实现论文中的unet网络


Posted in Python onJune 24, 2020

设计神经网络的一般步骤:

1. 设计框架

2. 设计骨干网络

Unet网络设计的步骤:

1. 设计Unet网络工厂模式

2. 设计编解码结构

3. 设计卷积模块

4. unet实例模块

Unet网络最重要的特征:

1. 编解码结构。

2. 解码结构,比FCN更加完善,采用连接方式。

3. 本质是一个框架,编码部分可以使用很多图像分类网络。

示例代码:

import torch
import torch.nn as nn

class Unet(nn.Module):
 #初始化参数:Encoder,Decoder,bridge
 #bridge默认值为无,如果有参数传入,则用该参数替换None
 def __init__(self,Encoder,Decoder,bridge = None):
  super(Unet,self).__init__()
  self.encoder = Encoder(encoder_blocks)
  self.decoder = Decoder(decoder_blocks)
  self.bridge = bridge
 def forward(self,x):
  res = self.encoder(x)
  out,skip = res[0],res[1,:]
  if bridge is not None:
   out = bridge(out)
  out = self.decoder(out,skip)
  return out
#设计编码模块
class Encoder(nn.Module):
 def __init__(self,blocks):
  super(Encoder,self).__init__()
  #assert:断言函数,避免出现参数错误
  assert len(blocks) > 0
  #nn.Modulelist():模型列表,所有的参数可以纳入网络,但是没有forward函数
  self.blocks = nn.Modulelist(blocks)
 def forward(self,x):
  skip = []
  for i in range(len(self.blocks) - 1):
   x = self.blocks[i](x)
   skip.append(x)
  res = [self.block[i+1](x)]
  #列表之间可以通过+号拼接
  res += skip
  return res
#设计Decoder模块
class Decoder(nn.Module):
 def __init__(self,blocks):
  super(Decoder, self).__init__()
  assert len(blocks) > 0
  self.blocks = nn.Modulelist(blocks)
 def ceter_crop(self,skips,x):
  _,_,height1,width1 = skips.shape()
  _,_,height2,width2 = x.shape()
  #对图像进行剪切处理,拼接的时候保持对应size参数一致
  ht,wt = min(height1,height2),min(width1,width2)
  dh1 = (height1 - height2)//2 if height1 > height2 else 0
  dw1 = (width1 - width2)//2 if width1 > width2 else 0
  dh2 = (height2 - height1)//2 if height2 > height1 else 0
  dw2 = (width2 - width1)//2 if width2 > width1 else 0
  return skips[:,:,dh1:(dh1 + ht),dw1:(dw1 + wt)],\
    x[:,:,dh2:(dh2 + ht),dw2 : (dw2 + wt)]

 def forward(self, skips,x,reverse_skips = True):
  assert len(skips) == len(blocks) - 1
  if reverse_skips is True:
   skips = skips[: : -1]
  x = self.blocks[0](x)
  for i in range(1, len(self.blocks)):
   skip = skips[i-1]
   x = torch.cat(skip,x,1)
   x = self.blocks[i](x)
  return x
#定义了一个卷积block
def unet_convs(in_channels,out_channels,padding = 0):
 #nn.Sequential:与Modulelist相比,包含了forward函数
 return nn.Sequential(
  nn.Conv2d(in_channels, out_channels, kernal_size = 3, padding = padding, bias = False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace = True),
  nn.Conv2d(in_channels, out_channels, kernal_size=3, padding=padding, bias=False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace=True),
 )
#实例化Unet模型
def unet(in_channels,out_channels):
 encoder_blocks = [unet_convs(in_channels, 64),\
      nn.Sequential(nn.Maxpool2d(kernal_size = 2, stride = 2, ceil_mode = True),\
         unet_convs(64,128)), \
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(128, 256)),
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(256, 512)),
      ]
 bridge = nn.Sequential(unet_convs(512, 1024))
 decoder_blocks = [nn.conTranpose2d(1024, 512), \
      nn.Sequential(unet_convs(1024, 512),
         nn.conTranpose2d(512, 256)),\
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(256, 128),
         nn.conTranpose2d(128, 64))
      ]
 return Unet(encoder_blocks,decoder_blocks,bridge)

补充知识:Pytorch搭建U-Net网络

U-Net: Convolutional Networks for Biomedical Image Segmentation

使用pytorch实现论文中的unet网络

import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary

class DoubleConv(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(DoubleConv, self).__init__()
  self.conv = nn.Sequential(
   nn.Conv2d(in_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True),
   nn.Conv2d(out_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True)
  )

 def forward(self, input):
  return self.conv(input)

class Unet(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(Unet, self).__init__()
  self.conv1 = DoubleConv(in_ch, 64)
  self.pool1 = nn.MaxPool2d(2)
  self.conv2 = DoubleConv(64, 128)
  self.pool2 = nn.MaxPool2d(2)
  self.conv3 = DoubleConv(128, 256)
  self.pool3 = nn.MaxPool2d(2)
  self.conv4 = DoubleConv(256, 512)
  self.pool4 = nn.MaxPool2d(2)
  self.conv5 = DoubleConv(512, 1024)
  # 逆卷积,也可以使用上采样
  self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
  self.conv6 = DoubleConv(1024, 512)
  self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
  self.conv7 = DoubleConv(512, 256)
  self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
  self.conv8 = DoubleConv(256, 128)
  self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
  self.conv9 = DoubleConv(128, 64)
  self.conv10 = nn.Conv2d(64, out_ch, 1)

 def forward(self, x):
  c1 = self.conv1(x)
  crop1 = c1[:,:,88:480,88:480]
  p1 = self.pool1(c1)
  c2 = self.conv2(p1)
  crop2 = c2[:,:,40:240,40:240]
  p2 = self.pool2(c2)
  c3 = self.conv3(p2)
  crop3 = c3[:,:,16:120,16:120]
  p3 = self.pool3(c3)
  c4 = self.conv4(p3)
  crop4 = c4[:,:,4:60,4:60]
  p4 = self.pool4(c4)
  c5 = self.conv5(p4)
  up_6 = self.up6(c5)
  merge6 = torch.cat([up_6, crop4], dim=1)
  c6 = self.conv6(merge6)
  up_7 = self.up7(c6)
  merge7 = torch.cat([up_7, crop3], dim=1)
  c7 = self.conv7(merge7)
  up_8 = self.up8(c7)
  merge8 = torch.cat([up_8, crop2], dim=1)
  c8 = self.conv8(merge8)
  up_9 = self.up9(c8)
  merge9 = torch.cat([up_9, crop1], dim=1)
  c9 = self.conv9(merge9)
  c10 = self.conv10(c9)
  out = nn.Sigmoid()(c10)
  return out

if __name__=="__main__":
 test_input=torch.rand(1, 1, 572, 572)
 model=Unet(in_ch=1, out_ch=2)
 summary(model, (1,572,572))
 ouput=model(test_input)
 print(ouput.size())
----------------------------------------------------------------
  Layer (type)    Output Shape   Param #
================================================================
   Conv2d-1   [-1, 64, 570, 570]    640
  BatchNorm2d-2   [-1, 64, 570, 570]    128
    ReLU-3   [-1, 64, 570, 570]    0
   Conv2d-4   [-1, 64, 568, 568]   36,928
  BatchNorm2d-5   [-1, 64, 568, 568]    128
    ReLU-6   [-1, 64, 568, 568]    0
  DoubleConv-7   [-1, 64, 568, 568]    0
   MaxPool2d-8   [-1, 64, 284, 284]    0
   Conv2d-9  [-1, 128, 282, 282]   73,856
  BatchNorm2d-10  [-1, 128, 282, 282]    256
    ReLU-11  [-1, 128, 282, 282]    0
   Conv2d-12  [-1, 128, 280, 280]   147,584
  BatchNorm2d-13  [-1, 128, 280, 280]    256
    ReLU-14  [-1, 128, 280, 280]    0
  DoubleConv-15  [-1, 128, 280, 280]    0
  MaxPool2d-16  [-1, 128, 140, 140]    0
   Conv2d-17  [-1, 256, 138, 138]   295,168
  BatchNorm2d-18  [-1, 256, 138, 138]    512
    ReLU-19  [-1, 256, 138, 138]    0
   Conv2d-20  [-1, 256, 136, 136]   590,080
  BatchNorm2d-21  [-1, 256, 136, 136]    512
    ReLU-22  [-1, 256, 136, 136]    0
  DoubleConv-23  [-1, 256, 136, 136]    0
  MaxPool2d-24   [-1, 256, 68, 68]    0
   Conv2d-25   [-1, 512, 66, 66]  1,180,160
  BatchNorm2d-26   [-1, 512, 66, 66]   1,024
    ReLU-27   [-1, 512, 66, 66]    0
   Conv2d-28   [-1, 512, 64, 64]  2,359,808
  BatchNorm2d-29   [-1, 512, 64, 64]   1,024
    ReLU-30   [-1, 512, 64, 64]    0
  DoubleConv-31   [-1, 512, 64, 64]    0
  MaxPool2d-32   [-1, 512, 32, 32]    0
   Conv2d-33   [-1, 1024, 30, 30]  4,719,616
  BatchNorm2d-34   [-1, 1024, 30, 30]   2,048
    ReLU-35   [-1, 1024, 30, 30]    0
   Conv2d-36   [-1, 1024, 28, 28]  9,438,208
  BatchNorm2d-37   [-1, 1024, 28, 28]   2,048
    ReLU-38   [-1, 1024, 28, 28]    0
  DoubleConv-39   [-1, 1024, 28, 28]    0
 ConvTranspose2d-40   [-1, 512, 56, 56]  2,097,664
   Conv2d-41   [-1, 512, 54, 54]  4,719,104
  BatchNorm2d-42   [-1, 512, 54, 54]   1,024
    ReLU-43   [-1, 512, 54, 54]    0
   Conv2d-44   [-1, 512, 52, 52]  2,359,808
  BatchNorm2d-45   [-1, 512, 52, 52]   1,024
    ReLU-46   [-1, 512, 52, 52]    0
  DoubleConv-47   [-1, 512, 52, 52]    0
 ConvTranspose2d-48  [-1, 256, 104, 104]   524,544
   Conv2d-49  [-1, 256, 102, 102]  1,179,904
  BatchNorm2d-50  [-1, 256, 102, 102]    512
    ReLU-51  [-1, 256, 102, 102]    0
   Conv2d-52  [-1, 256, 100, 100]   590,080
  BatchNorm2d-53  [-1, 256, 100, 100]    512
    ReLU-54  [-1, 256, 100, 100]    0
  DoubleConv-55  [-1, 256, 100, 100]    0
 ConvTranspose2d-56  [-1, 128, 200, 200]   131,200
   Conv2d-57  [-1, 128, 198, 198]   295,040
  BatchNorm2d-58  [-1, 128, 198, 198]    256
    ReLU-59  [-1, 128, 198, 198]    0
   Conv2d-60  [-1, 128, 196, 196]   147,584
  BatchNorm2d-61  [-1, 128, 196, 196]    256
    ReLU-62  [-1, 128, 196, 196]    0
  DoubleConv-63  [-1, 128, 196, 196]    0
 ConvTranspose2d-64   [-1, 64, 392, 392]   32,832
   Conv2d-65   [-1, 64, 390, 390]   73,792
  BatchNorm2d-66   [-1, 64, 390, 390]    128
    ReLU-67   [-1, 64, 390, 390]    0
   Conv2d-68   [-1, 64, 388, 388]   36,928
  BatchNorm2d-69   [-1, 64, 388, 388]    128
    ReLU-70   [-1, 64, 388, 388]    0
  DoubleConv-71   [-1, 64, 388, 388]    0
   Conv2d-72   [-1, 2, 388, 388]    130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])

以上这篇使用pytorch实现论文中的unet网络就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python返回真假值(True or False)小技巧
Apr 10 Python
详解Python中DOM方法的动态性
Apr 11 Python
详解python中asyncio模块
Mar 03 Python
Python异常处理操作实例详解
May 10 Python
对python sklearn one-hot编码详解
Jul 10 Python
Python实现查找最小的k个数示例【两种解法】
Jan 08 Python
利用Python查看微信共同好友功能的实现代码
Apr 24 Python
PyQt5的PyQtGraph实践系列3之实时数据更新绘制图形
May 13 Python
python使用flask与js进行前后台交互的例子
Jul 19 Python
如何用tempfile库创建python进程中的临时文件
Jan 28 Python
一文带你掌握Pyecharts地理数据可视化的方法
Feb 06 Python
Python中的matplotlib绘制百分比堆叠柱状图,并为每一个类别设置不同的填充图案
Apr 20 Python
python连接mysql有哪些方法
Jun 24 #Python
pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
Jun 24 #Python
Python Tornado核心及相关原理详解
Jun 24 #Python
如何使用Python处理HDF格式数据及可视化问题
Jun 24 #Python
pytorch SENet实现案例
Jun 24 #Python
利用PyTorch实现VGG16教程
Jun 24 #Python
python安装读取grib库总结(推荐)
Jun 24 #Python
You might like
用windows下编译过的eAccelerator for PHP 5.1.6实现php加速的使用方法
2007/09/30 PHP
PHP 多维数组排序实现代码
2009/08/05 PHP
PHP-CGI进程CPU 100% 与 file_get_contents 函数的关系分析
2011/08/15 PHP
PHP二维数组矩形转置实例
2016/07/20 PHP
三个思路解决laravel上传文件报错:413 Request Entity Too Large问题
2017/11/13 PHP
Ext4.2的Ext.grid.plugin.RowExpander无法触发事件解决办法
2014/08/15 Javascript
Javascript毫秒数用法实例
2015/02/05 Javascript
jQuery实现个性翻牌效果导航菜单的方法
2015/03/09 Javascript
js实现基于正则表达式的轻量提示插件
2015/08/29 Javascript
jQuery自动完成插件completer附源码下载
2016/01/04 Javascript
JQuery为元素添加样式的实现方法
2016/07/20 Javascript
AngularJS基础 ng-csp 指令详解
2016/08/01 Javascript
微信小程序  http请求封装详解及实例代码
2017/02/15 Javascript
AngularJS路由实现页面跳转实例
2017/03/03 Javascript
详解node nvm进行node多版本管理
2017/10/21 Javascript
最实用的JS数组函数整理
2017/12/05 Javascript
利用JavaScript的Map提升性能的方法详解
2019/08/14 Javascript
[02:12]2015国际邀请赛 SHOWOPEN
2015/08/05 DOTA
[03:36]DOTA2完美大师赛coL战队趣味视频——我演你猜
2017/11/23 DOTA
python在控制台输出进度条的方法
2015/06/20 Python
解决DataFrame排序sort的问题
2018/06/07 Python
django从请求到响应的过程深入讲解
2018/08/01 Python
正确理解Python中if __name__ == '__main__'
2019/01/24 Python
详解Python给照片换底色(蓝底换红底)
2019/03/22 Python
Pyqt QImage 与 np array 转换方法
2019/06/27 Python
解决python中显示图片的plt.imshow plt.show()内存泄漏问题
2020/04/24 Python
python将logging模块封装成单独模块并实现动态切换Level方式
2020/05/12 Python
pip install命令安装扩展库整理
2021/03/02 Python
Html5之svg可缩放矢量图形_动力节点Java学院整理
2017/07/17 HTML / CSS
英国高街电视:High Street TV
2018/05/22 全球购物
管理学专业个人求职信范文
2013/09/21 职场文书
应届生文秘专业个人自荐信格式
2013/09/21 职场文书
会计专业毕业生自荐信范文
2013/12/20 职场文书
元旦获奖感言
2014/03/08 职场文书
学生未请假就回家检讨书
2014/09/22 职场文书
Redis IP地址的绑定的实现
2021/05/08 Redis