使用pytorch实现论文中的unet网络


Posted in Python onJune 24, 2020

设计神经网络的一般步骤:

1. 设计框架

2. 设计骨干网络

Unet网络设计的步骤:

1. 设计Unet网络工厂模式

2. 设计编解码结构

3. 设计卷积模块

4. unet实例模块

Unet网络最重要的特征:

1. 编解码结构。

2. 解码结构,比FCN更加完善,采用连接方式。

3. 本质是一个框架,编码部分可以使用很多图像分类网络。

示例代码:

import torch
import torch.nn as nn

class Unet(nn.Module):
 #初始化参数:Encoder,Decoder,bridge
 #bridge默认值为无,如果有参数传入,则用该参数替换None
 def __init__(self,Encoder,Decoder,bridge = None):
  super(Unet,self).__init__()
  self.encoder = Encoder(encoder_blocks)
  self.decoder = Decoder(decoder_blocks)
  self.bridge = bridge
 def forward(self,x):
  res = self.encoder(x)
  out,skip = res[0],res[1,:]
  if bridge is not None:
   out = bridge(out)
  out = self.decoder(out,skip)
  return out
#设计编码模块
class Encoder(nn.Module):
 def __init__(self,blocks):
  super(Encoder,self).__init__()
  #assert:断言函数,避免出现参数错误
  assert len(blocks) > 0
  #nn.Modulelist():模型列表,所有的参数可以纳入网络,但是没有forward函数
  self.blocks = nn.Modulelist(blocks)
 def forward(self,x):
  skip = []
  for i in range(len(self.blocks) - 1):
   x = self.blocks[i](x)
   skip.append(x)
  res = [self.block[i+1](x)]
  #列表之间可以通过+号拼接
  res += skip
  return res
#设计Decoder模块
class Decoder(nn.Module):
 def __init__(self,blocks):
  super(Decoder, self).__init__()
  assert len(blocks) > 0
  self.blocks = nn.Modulelist(blocks)
 def ceter_crop(self,skips,x):
  _,_,height1,width1 = skips.shape()
  _,_,height2,width2 = x.shape()
  #对图像进行剪切处理,拼接的时候保持对应size参数一致
  ht,wt = min(height1,height2),min(width1,width2)
  dh1 = (height1 - height2)//2 if height1 > height2 else 0
  dw1 = (width1 - width2)//2 if width1 > width2 else 0
  dh2 = (height2 - height1)//2 if height2 > height1 else 0
  dw2 = (width2 - width1)//2 if width2 > width1 else 0
  return skips[:,:,dh1:(dh1 + ht),dw1:(dw1 + wt)],\
    x[:,:,dh2:(dh2 + ht),dw2 : (dw2 + wt)]

 def forward(self, skips,x,reverse_skips = True):
  assert len(skips) == len(blocks) - 1
  if reverse_skips is True:
   skips = skips[: : -1]
  x = self.blocks[0](x)
  for i in range(1, len(self.blocks)):
   skip = skips[i-1]
   x = torch.cat(skip,x,1)
   x = self.blocks[i](x)
  return x
#定义了一个卷积block
def unet_convs(in_channels,out_channels,padding = 0):
 #nn.Sequential:与Modulelist相比,包含了forward函数
 return nn.Sequential(
  nn.Conv2d(in_channels, out_channels, kernal_size = 3, padding = padding, bias = False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace = True),
  nn.Conv2d(in_channels, out_channels, kernal_size=3, padding=padding, bias=False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace=True),
 )
#实例化Unet模型
def unet(in_channels,out_channels):
 encoder_blocks = [unet_convs(in_channels, 64),\
      nn.Sequential(nn.Maxpool2d(kernal_size = 2, stride = 2, ceil_mode = True),\
         unet_convs(64,128)), \
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(128, 256)),
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(256, 512)),
      ]
 bridge = nn.Sequential(unet_convs(512, 1024))
 decoder_blocks = [nn.conTranpose2d(1024, 512), \
      nn.Sequential(unet_convs(1024, 512),
         nn.conTranpose2d(512, 256)),\
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(256, 128),
         nn.conTranpose2d(128, 64))
      ]
 return Unet(encoder_blocks,decoder_blocks,bridge)

补充知识:Pytorch搭建U-Net网络

U-Net: Convolutional Networks for Biomedical Image Segmentation

使用pytorch实现论文中的unet网络

import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary

class DoubleConv(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(DoubleConv, self).__init__()
  self.conv = nn.Sequential(
   nn.Conv2d(in_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True),
   nn.Conv2d(out_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True)
  )

 def forward(self, input):
  return self.conv(input)

class Unet(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(Unet, self).__init__()
  self.conv1 = DoubleConv(in_ch, 64)
  self.pool1 = nn.MaxPool2d(2)
  self.conv2 = DoubleConv(64, 128)
  self.pool2 = nn.MaxPool2d(2)
  self.conv3 = DoubleConv(128, 256)
  self.pool3 = nn.MaxPool2d(2)
  self.conv4 = DoubleConv(256, 512)
  self.pool4 = nn.MaxPool2d(2)
  self.conv5 = DoubleConv(512, 1024)
  # 逆卷积,也可以使用上采样
  self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
  self.conv6 = DoubleConv(1024, 512)
  self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
  self.conv7 = DoubleConv(512, 256)
  self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
  self.conv8 = DoubleConv(256, 128)
  self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
  self.conv9 = DoubleConv(128, 64)
  self.conv10 = nn.Conv2d(64, out_ch, 1)

 def forward(self, x):
  c1 = self.conv1(x)
  crop1 = c1[:,:,88:480,88:480]
  p1 = self.pool1(c1)
  c2 = self.conv2(p1)
  crop2 = c2[:,:,40:240,40:240]
  p2 = self.pool2(c2)
  c3 = self.conv3(p2)
  crop3 = c3[:,:,16:120,16:120]
  p3 = self.pool3(c3)
  c4 = self.conv4(p3)
  crop4 = c4[:,:,4:60,4:60]
  p4 = self.pool4(c4)
  c5 = self.conv5(p4)
  up_6 = self.up6(c5)
  merge6 = torch.cat([up_6, crop4], dim=1)
  c6 = self.conv6(merge6)
  up_7 = self.up7(c6)
  merge7 = torch.cat([up_7, crop3], dim=1)
  c7 = self.conv7(merge7)
  up_8 = self.up8(c7)
  merge8 = torch.cat([up_8, crop2], dim=1)
  c8 = self.conv8(merge8)
  up_9 = self.up9(c8)
  merge9 = torch.cat([up_9, crop1], dim=1)
  c9 = self.conv9(merge9)
  c10 = self.conv10(c9)
  out = nn.Sigmoid()(c10)
  return out

if __name__=="__main__":
 test_input=torch.rand(1, 1, 572, 572)
 model=Unet(in_ch=1, out_ch=2)
 summary(model, (1,572,572))
 ouput=model(test_input)
 print(ouput.size())
----------------------------------------------------------------
  Layer (type)    Output Shape   Param #
================================================================
   Conv2d-1   [-1, 64, 570, 570]    640
  BatchNorm2d-2   [-1, 64, 570, 570]    128
    ReLU-3   [-1, 64, 570, 570]    0
   Conv2d-4   [-1, 64, 568, 568]   36,928
  BatchNorm2d-5   [-1, 64, 568, 568]    128
    ReLU-6   [-1, 64, 568, 568]    0
  DoubleConv-7   [-1, 64, 568, 568]    0
   MaxPool2d-8   [-1, 64, 284, 284]    0
   Conv2d-9  [-1, 128, 282, 282]   73,856
  BatchNorm2d-10  [-1, 128, 282, 282]    256
    ReLU-11  [-1, 128, 282, 282]    0
   Conv2d-12  [-1, 128, 280, 280]   147,584
  BatchNorm2d-13  [-1, 128, 280, 280]    256
    ReLU-14  [-1, 128, 280, 280]    0
  DoubleConv-15  [-1, 128, 280, 280]    0
  MaxPool2d-16  [-1, 128, 140, 140]    0
   Conv2d-17  [-1, 256, 138, 138]   295,168
  BatchNorm2d-18  [-1, 256, 138, 138]    512
    ReLU-19  [-1, 256, 138, 138]    0
   Conv2d-20  [-1, 256, 136, 136]   590,080
  BatchNorm2d-21  [-1, 256, 136, 136]    512
    ReLU-22  [-1, 256, 136, 136]    0
  DoubleConv-23  [-1, 256, 136, 136]    0
  MaxPool2d-24   [-1, 256, 68, 68]    0
   Conv2d-25   [-1, 512, 66, 66]  1,180,160
  BatchNorm2d-26   [-1, 512, 66, 66]   1,024
    ReLU-27   [-1, 512, 66, 66]    0
   Conv2d-28   [-1, 512, 64, 64]  2,359,808
  BatchNorm2d-29   [-1, 512, 64, 64]   1,024
    ReLU-30   [-1, 512, 64, 64]    0
  DoubleConv-31   [-1, 512, 64, 64]    0
  MaxPool2d-32   [-1, 512, 32, 32]    0
   Conv2d-33   [-1, 1024, 30, 30]  4,719,616
  BatchNorm2d-34   [-1, 1024, 30, 30]   2,048
    ReLU-35   [-1, 1024, 30, 30]    0
   Conv2d-36   [-1, 1024, 28, 28]  9,438,208
  BatchNorm2d-37   [-1, 1024, 28, 28]   2,048
    ReLU-38   [-1, 1024, 28, 28]    0
  DoubleConv-39   [-1, 1024, 28, 28]    0
 ConvTranspose2d-40   [-1, 512, 56, 56]  2,097,664
   Conv2d-41   [-1, 512, 54, 54]  4,719,104
  BatchNorm2d-42   [-1, 512, 54, 54]   1,024
    ReLU-43   [-1, 512, 54, 54]    0
   Conv2d-44   [-1, 512, 52, 52]  2,359,808
  BatchNorm2d-45   [-1, 512, 52, 52]   1,024
    ReLU-46   [-1, 512, 52, 52]    0
  DoubleConv-47   [-1, 512, 52, 52]    0
 ConvTranspose2d-48  [-1, 256, 104, 104]   524,544
   Conv2d-49  [-1, 256, 102, 102]  1,179,904
  BatchNorm2d-50  [-1, 256, 102, 102]    512
    ReLU-51  [-1, 256, 102, 102]    0
   Conv2d-52  [-1, 256, 100, 100]   590,080
  BatchNorm2d-53  [-1, 256, 100, 100]    512
    ReLU-54  [-1, 256, 100, 100]    0
  DoubleConv-55  [-1, 256, 100, 100]    0
 ConvTranspose2d-56  [-1, 128, 200, 200]   131,200
   Conv2d-57  [-1, 128, 198, 198]   295,040
  BatchNorm2d-58  [-1, 128, 198, 198]    256
    ReLU-59  [-1, 128, 198, 198]    0
   Conv2d-60  [-1, 128, 196, 196]   147,584
  BatchNorm2d-61  [-1, 128, 196, 196]    256
    ReLU-62  [-1, 128, 196, 196]    0
  DoubleConv-63  [-1, 128, 196, 196]    0
 ConvTranspose2d-64   [-1, 64, 392, 392]   32,832
   Conv2d-65   [-1, 64, 390, 390]   73,792
  BatchNorm2d-66   [-1, 64, 390, 390]    128
    ReLU-67   [-1, 64, 390, 390]    0
   Conv2d-68   [-1, 64, 388, 388]   36,928
  BatchNorm2d-69   [-1, 64, 388, 388]    128
    ReLU-70   [-1, 64, 388, 388]    0
  DoubleConv-71   [-1, 64, 388, 388]    0
   Conv2d-72   [-1, 2, 388, 388]    130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])

以上这篇使用pytorch实现论文中的unet网络就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python创建临时文件夹的方法
Jul 06 Python
Python实现的计数排序算法示例
Nov 29 Python
Python设计模式之门面模式简单示例
Jan 09 Python
python实现词法分析器
Jan 31 Python
Python发送邮件封装实现过程详解
May 09 Python
keras绘制acc和loss曲线图实例
Jun 15 Python
python中线程和进程有何区别
Jun 17 Python
Django创建一个后台的基本步骤记录
Oct 02 Python
Selenium执行完毕未关闭chromedriver/geckodriver进程的解决办法(java版+python版)
Dec 07 Python
pytorch实现线性回归以及多元回归
Apr 11 Python
python基于opencv批量生成验证码的示例
Apr 28 Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
Jun 05 Python
python连接mysql有哪些方法
Jun 24 #Python
pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
Jun 24 #Python
Python Tornado核心及相关原理详解
Jun 24 #Python
如何使用Python处理HDF格式数据及可视化问题
Jun 24 #Python
pytorch SENet实现案例
Jun 24 #Python
利用PyTorch实现VGG16教程
Jun 24 #Python
python安装读取grib库总结(推荐)
Jun 24 #Python
You might like
PHP中GET变量的使用
2006/10/09 PHP
mac下使用brew配置环境的步骤分享
2011/05/23 PHP
php统计时间和内存使用情况示例分享
2014/03/13 PHP
如何使用Gitblog和Markdown建自己的博客
2015/07/31 PHP
php将远程图片保存到本地服务器的实现代码
2015/08/03 PHP
简介WordPress中用于获取首页和站点链接的PHP函数
2015/12/17 PHP
PHP 芝麻信用接入的注意事项
2016/12/01 PHP
浅谈php中curl、fsockopen的应用
2016/12/10 PHP
php 常用的系统函数
2017/02/07 PHP
PHP单文件上传原理及上传函数的封装操作示例
2019/09/02 PHP
JS Range HTML文档/文字内容选中、库及应用介绍
2011/05/12 Javascript
深入Javascript函数、递归与闭包(执行环境、变量对象与作用域链)使用详解
2013/05/08 Javascript
jquery仅用6行代码实现滑动门效果
2015/09/07 Javascript
谈谈JavaScript类型系统之Math
2016/01/06 Javascript
jQuery实现的分子运动小球碰撞效果
2016/01/27 Javascript
微信小程序 网络API Websocket详解
2016/11/09 Javascript
ES6 javascript中Class类继承用法实例详解
2017/10/30 Javascript
vue jsx 使用指南及vue.js 使用jsx语法的方法
2017/11/11 Javascript
详解ES6中的Map与Set集合
2019/03/22 Javascript
LayUi使用switch开关,动态的去控制它是否被启用的方法
2019/09/21 Javascript
详解Vue.js 可拖放文本框组件的使用
2021/03/03 Vue.js
Python数据结构之栈、队列的实现代码分享
2017/12/04 Python
解决Pycharm出现的部分快捷键无效问题
2018/10/22 Python
python中对数据进行各种排序的方法
2019/07/02 Python
python中@contextmanager实例用法
2021/02/07 Python
CSS3中border-radius属性设定圆角的使用技巧
2016/05/10 HTML / CSS
Lee牛仔裤澳大利亚官网:美国著名牛仔裤品牌
2017/09/02 全球购物
英国排名第一的餐具品牌:Denby Pottery
2019/11/01 全球购物
博士毕业生自我鉴定范文
2014/04/13 职场文书
3分钟英语演讲稿
2014/04/29 职场文书
会计电算化专业求职信
2014/06/10 职场文书
乡镇党的群众路线教育实践活动制度建设计划
2014/11/03 职场文书
音乐教师个人工作总结
2015/02/06 职场文书
有关信念的名言语录集锦
2019/12/06 职场文书
解决Windows Server2012 R2 无法安装 .NET Framework 3.5
2022/04/29 Servers
PostgreSQL常用字符串分割函数整理汇总
2022/07/07 PostgreSQL