Keras:Unet网络实现多类语义分割方式


Posted in Python onJune 11, 2020

1 介绍

U-Net最初是用来对医学图像的语义分割,后来也有人将其应用于其他领域。但大多还是用来进行二分类,即将原始图像分成两个灰度级或者色度,依次找到图像中感兴趣的目标部分。

本文主要利用U-Net网络结构实现了多类的语义分割,并展示了部分测试效果,希望对你有用!

2 源代码

(1)训练模型

from __future__ import print_function
import os
import datetime
import numpy as np
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, AveragePooling2D, Dropout, \
 BatchNormalization
from keras.optimizers import Adam
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from keras.layers.advanced_activations import LeakyReLU, ReLU
import cv2
 
PIXEL = 512 #set your image size
BATCH_SIZE = 5
lr = 0.001
EPOCH = 100
X_CHANNEL = 3 # training images channel
Y_CHANNEL = 1 # label iamges channel
X_NUM = 422 # your traning data number
 
pathX = 'I:\\Pascal VOC Dataset\\train1\\images\\' #change your file path
pathY = 'I:\\Pascal VOC Dataset\\train1\\SegmentationObject\\' #change your file path
 
#data processing
def generator(pathX, pathY,BATCH_SIZE):
 while 1:
  X_train_files = os.listdir(pathX)
  Y_train_files = os.listdir(pathY)
  a = (np.arange(1, X_NUM))
  X = []
  Y = []
  for i in range(BATCH_SIZE):
   index = np.random.choice(a)
   # print(index)
   img = cv2.imread(pathX + X_train_files[index], 1)
   img = np.array(img).reshape(PIXEL, PIXEL, X_CHANNEL)
   X.append(img)
   img1 = cv2.imread(pathY + Y_train_files[index], 1)
   img1 = np.array(img1).reshape(PIXEL, PIXEL, Y_CHANNEL)
   Y.append(img1)
 
  X = np.array(X)
  Y = np.array(Y)
  yield X, Y
 
 #creat unet network
inputs = Input((PIXEL, PIXEL, 3))
conv1 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
pool1 = AveragePooling2D(pool_size=(2, 2))(conv1) # 16
 
conv2 = BatchNormalization(momentum=0.99)(pool1)
conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = BatchNormalization(momentum=0.99)(conv2)
conv2 = Conv2D(64, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = Dropout(0.02)(conv2)
pool2 = AveragePooling2D(pool_size=(2, 2))(conv2) # 8
 
conv3 = BatchNormalization(momentum=0.99)(pool2)
conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = BatchNormalization(momentum=0.99)(conv3)
conv3 = Conv2D(128, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = Dropout(0.02)(conv3)
pool3 = AveragePooling2D(pool_size=(2, 2))(conv3) # 4
 
conv4 = BatchNormalization(momentum=0.99)(pool3)
conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = BatchNormalization(momentum=0.99)(conv4)
conv4 = Conv2D(256, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = Dropout(0.02)(conv4)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)
 
conv5 = BatchNormalization(momentum=0.99)(pool4)
conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = BatchNormalization(momentum=0.99)(conv5)
conv5 = Conv2D(512, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = Dropout(0.02)(conv5)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)
# conv5 = Conv2D(35, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
# drop4 = Dropout(0.02)(conv5)
pool4 = AveragePooling2D(pool_size=(2, 2))(pool3) # 2
pool5 = AveragePooling2D(pool_size=(2, 2))(pool4) # 1
 
conv6 = BatchNormalization(momentum=0.99)(pool5)
conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
 
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
up7 = (UpSampling2D(size=(2, 2))(conv7)) # 2
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
merge7 = concatenate([pool4, conv7], axis=3)
 
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
up8 = (UpSampling2D(size=(2, 2))(conv8)) # 4
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
merge8 = concatenate([pool3, conv8], axis=3)
 
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
up9 = (UpSampling2D(size=(2, 2))(conv9)) # 8
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
merge9 = concatenate([pool2, conv9], axis=3)
 
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
up10 = (UpSampling2D(size=(2, 2))(conv10)) # 16
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up10)
 
conv11 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv10)
up11 = (UpSampling2D(size=(2, 2))(conv11)) # 32
conv11 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up11)
 
# conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
 
model = Model(input=inputs, output=conv12)
print(model.summary())
model.compile(optimizer=Adam(lr=1e-3), loss='mse', metrics=['accuracy'])
 
history = model.fit_generator(generator(pathX, pathY,BATCH_SIZE),
        steps_per_epoch=600, nb_epoch=EPOCH)
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
 
 #save your training model
model.save(r'V1_828.h5')
 
#save your loss data
mse = np.array((history.history['loss']))
np.save(r'V1_828.npy', mse)

(2)测试模型

from keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
 
model = load_model('V1_828.h5')
test_images_path = 'I:\\Pascal VOC Dataset\\test\\test_images\\'
test_gt_path = 'I:\\Pascal VOC Dataset\\test\\SegmentationObject\\'
pre_path = 'I:\\Pascal VOC Dataset\\test\\pre\\'
 
X = []
for info in os.listdir(test_images_path):
 A = cv2.imread(test_images_path + info)
 X.append(A)
 # i += 1
X = np.array(X)
print(X.shape)
Y = model.predict(X)
 
groudtruth = []
for info in os.listdir(test_gt_path):
 A = cv2.imread(test_gt_path + info)
 groudtruth.append(A)
groudtruth = np.array(groudtruth)
 
i = 0
for info in os.listdir(test_images_path):
 cv2.imwrite(pre_path + info,Y[i])
 i += 1
 
a = range(10)
n = np.random.choice(a)
cv2.imwrite('prediction.png',Y[n])
cv2.imwrite('groudtruth.png',groudtruth[n])
fig, axs = plt.subplots(1, 3)
# cnt = 1
# for j in range(1):
axs[0].imshow(np.abs(X[n]))
axs[0].axis('off')
axs[1].imshow(np.abs(Y[n]))
axs[1].axis('off')
axs[2].imshow(np.abs(groudtruth[n]))
axs[2].axis('off')
 # cnt += 1
fig.savefig("imagestest.png")
plt.close()

3 效果展示

说明:从左到右依次是预测图像,真实图像,标注图像。可以看出,对于部分数据的分割效果还有待改进,主要原因还是数据集相对复杂,模型难于找到其中的规律。

Keras:Unet网络实现多类语义分割方式

以上这篇Keras:Unet网络实现多类语义分割方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中__call__方法示例分析
Oct 11 Python
在Python中处理列表之reverse()方法的使用教程
May 21 Python
python获取外网ip地址的方法总结
Jul 02 Python
Python IDLE入门简介
Dec 08 Python
利用Python在一个文件的头部插入数据的实例
May 02 Python
对python:print打印时加u的含义详解
Dec 15 Python
浅析Python 引号、注释、字符串
Jul 25 Python
python向图片里添加文字
Nov 26 Python
Python通过Pillow实现图片对比
Apr 29 Python
基于python和flask实现http接口过程解析
Jun 15 Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 Python
pycharm最新激活码有效期至2100年(亲测可用)
Feb 05 Python
Pycharm中配置远程Docker运行环境的教程图解
Jun 11 #Python
Keras 快速解决OOM超内存的问题
Jun 11 #Python
python3.6.8 + pycharm + PyQt5 环境搭建的图文教程
Jun 11 #Python
使用keras实现孪生网络中的权值共享教程
Jun 11 #Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
You might like
第十节--抽象方法和抽象类
2006/11/16 PHP
php 上一篇,下一篇文章实现代码与原理说明
2010/05/09 PHP
服务器迁移php版本不同可能诱发的问题
2015/12/22 PHP
轻松掌握php设计模式之访问者模式
2016/09/23 PHP
thinkPHP5 tablib标签库自定义方法详解
2017/05/10 PHP
js常用排序实现代码
2010/12/28 Javascript
js中关于String对象的replace使用详解
2011/05/24 Javascript
javascript判断office版本示例
2014/04/11 Javascript
深入理解JavaScript系列(31):设计模式之代理模式详解
2015/03/03 Javascript
jquery跟随屏幕滚动效果的实现代码
2016/04/13 Javascript
JavaScript判断微信浏览器实例代码
2016/06/13 Javascript
jquery自动补齐功能插件flexselect用法示例
2016/08/06 Javascript
JavaScript 随机验证码的生成实例代码
2016/09/22 Javascript
Java  Spring 事务回滚详解
2016/10/17 Javascript
20行JS代码实现网页刮刮乐效果
2017/06/23 Javascript
微信小程序 上传头像的实例详解
2017/10/27 Javascript
VueJs使用Amaze ui调整列表和内容页面
2017/11/30 Javascript
Vue导出json数据到Excel电子表格的示例
2017/12/04 Javascript
js实现移动端tab切换时下划线滑动效果
2019/09/08 Javascript
关于在LayUI中使用AJAX提交巨坑记录
2019/10/25 Javascript
在vue中使用inheritAttrs实现组件的扩展性介绍
2020/12/07 Vue.js
在Python中操作文件之truncate()方法的使用教程
2015/05/25 Python
使用Python神器对付12306变态验证码
2016/01/05 Python
基于Python 的进程管理工具supervisor使用指南
2016/09/18 Python
Django实现学员管理系统
2019/02/26 Python
pytorch多进程加速及代码优化方法
2019/08/19 Python
python实现批量命名照片
2020/06/18 Python
解决Keras中CNN输入维度报错问题
2020/06/29 Python
Python 实现一个计时器
2020/07/28 Python
python实现简单文件读写函数
2021/02/25 Python
用纯css3实现的图片放大镜特效效果非常不错
2014/09/02 HTML / CSS
解决img标签上下出现间隙的方法
2016/12/14 HTML / CSS
STAUD官方网站:洛杉矶独有的闲适风格
2019/04/11 全球购物
大班开学家长寄语
2014/04/04 职场文书
医生个人年度总结
2015/02/28 职场文书
教你使用RustDesk 搭建一个自己的远程桌面中继服务器
2022/08/14 Servers