Keras:Unet网络实现多类语义分割方式


Posted in Python onJune 11, 2020

1 介绍

U-Net最初是用来对医学图像的语义分割,后来也有人将其应用于其他领域。但大多还是用来进行二分类,即将原始图像分成两个灰度级或者色度,依次找到图像中感兴趣的目标部分。

本文主要利用U-Net网络结构实现了多类的语义分割,并展示了部分测试效果,希望对你有用!

2 源代码

(1)训练模型

from __future__ import print_function
import os
import datetime
import numpy as np
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, AveragePooling2D, Dropout, \
 BatchNormalization
from keras.optimizers import Adam
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from keras.layers.advanced_activations import LeakyReLU, ReLU
import cv2
 
PIXEL = 512 #set your image size
BATCH_SIZE = 5
lr = 0.001
EPOCH = 100
X_CHANNEL = 3 # training images channel
Y_CHANNEL = 1 # label iamges channel
X_NUM = 422 # your traning data number
 
pathX = 'I:\\Pascal VOC Dataset\\train1\\images\\' #change your file path
pathY = 'I:\\Pascal VOC Dataset\\train1\\SegmentationObject\\' #change your file path
 
#data processing
def generator(pathX, pathY,BATCH_SIZE):
 while 1:
  X_train_files = os.listdir(pathX)
  Y_train_files = os.listdir(pathY)
  a = (np.arange(1, X_NUM))
  X = []
  Y = []
  for i in range(BATCH_SIZE):
   index = np.random.choice(a)
   # print(index)
   img = cv2.imread(pathX + X_train_files[index], 1)
   img = np.array(img).reshape(PIXEL, PIXEL, X_CHANNEL)
   X.append(img)
   img1 = cv2.imread(pathY + Y_train_files[index], 1)
   img1 = np.array(img1).reshape(PIXEL, PIXEL, Y_CHANNEL)
   Y.append(img1)
 
  X = np.array(X)
  Y = np.array(Y)
  yield X, Y
 
 #creat unet network
inputs = Input((PIXEL, PIXEL, 3))
conv1 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
pool1 = AveragePooling2D(pool_size=(2, 2))(conv1) # 16
 
conv2 = BatchNormalization(momentum=0.99)(pool1)
conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = BatchNormalization(momentum=0.99)(conv2)
conv2 = Conv2D(64, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = Dropout(0.02)(conv2)
pool2 = AveragePooling2D(pool_size=(2, 2))(conv2) # 8
 
conv3 = BatchNormalization(momentum=0.99)(pool2)
conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = BatchNormalization(momentum=0.99)(conv3)
conv3 = Conv2D(128, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = Dropout(0.02)(conv3)
pool3 = AveragePooling2D(pool_size=(2, 2))(conv3) # 4
 
conv4 = BatchNormalization(momentum=0.99)(pool3)
conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = BatchNormalization(momentum=0.99)(conv4)
conv4 = Conv2D(256, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = Dropout(0.02)(conv4)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)
 
conv5 = BatchNormalization(momentum=0.99)(pool4)
conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = BatchNormalization(momentum=0.99)(conv5)
conv5 = Conv2D(512, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = Dropout(0.02)(conv5)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)
# conv5 = Conv2D(35, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
# drop4 = Dropout(0.02)(conv5)
pool4 = AveragePooling2D(pool_size=(2, 2))(pool3) # 2
pool5 = AveragePooling2D(pool_size=(2, 2))(pool4) # 1
 
conv6 = BatchNormalization(momentum=0.99)(pool5)
conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
 
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
up7 = (UpSampling2D(size=(2, 2))(conv7)) # 2
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
merge7 = concatenate([pool4, conv7], axis=3)
 
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
up8 = (UpSampling2D(size=(2, 2))(conv8)) # 4
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
merge8 = concatenate([pool3, conv8], axis=3)
 
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
up9 = (UpSampling2D(size=(2, 2))(conv9)) # 8
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
merge9 = concatenate([pool2, conv9], axis=3)
 
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
up10 = (UpSampling2D(size=(2, 2))(conv10)) # 16
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up10)
 
conv11 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv10)
up11 = (UpSampling2D(size=(2, 2))(conv11)) # 32
conv11 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up11)
 
# conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
 
model = Model(input=inputs, output=conv12)
print(model.summary())
model.compile(optimizer=Adam(lr=1e-3), loss='mse', metrics=['accuracy'])
 
history = model.fit_generator(generator(pathX, pathY,BATCH_SIZE),
        steps_per_epoch=600, nb_epoch=EPOCH)
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
 
 #save your training model
model.save(r'V1_828.h5')
 
#save your loss data
mse = np.array((history.history['loss']))
np.save(r'V1_828.npy', mse)

(2)测试模型

from keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
 
model = load_model('V1_828.h5')
test_images_path = 'I:\\Pascal VOC Dataset\\test\\test_images\\'
test_gt_path = 'I:\\Pascal VOC Dataset\\test\\SegmentationObject\\'
pre_path = 'I:\\Pascal VOC Dataset\\test\\pre\\'
 
X = []
for info in os.listdir(test_images_path):
 A = cv2.imread(test_images_path + info)
 X.append(A)
 # i += 1
X = np.array(X)
print(X.shape)
Y = model.predict(X)
 
groudtruth = []
for info in os.listdir(test_gt_path):
 A = cv2.imread(test_gt_path + info)
 groudtruth.append(A)
groudtruth = np.array(groudtruth)
 
i = 0
for info in os.listdir(test_images_path):
 cv2.imwrite(pre_path + info,Y[i])
 i += 1
 
a = range(10)
n = np.random.choice(a)
cv2.imwrite('prediction.png',Y[n])
cv2.imwrite('groudtruth.png',groudtruth[n])
fig, axs = plt.subplots(1, 3)
# cnt = 1
# for j in range(1):
axs[0].imshow(np.abs(X[n]))
axs[0].axis('off')
axs[1].imshow(np.abs(Y[n]))
axs[1].axis('off')
axs[2].imshow(np.abs(groudtruth[n]))
axs[2].axis('off')
 # cnt += 1
fig.savefig("imagestest.png")
plt.close()

3 效果展示

说明:从左到右依次是预测图像,真实图像,标注图像。可以看出,对于部分数据的分割效果还有待改进,主要原因还是数据集相对复杂,模型难于找到其中的规律。

Keras:Unet网络实现多类语义分割方式

以上这篇Keras:Unet网络实现多类语义分割方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python获取Linux系统的各种信息
Jul 10 Python
举例讲解Python设计模式编程中对抽象工厂模式的运用
Mar 02 Python
python3实现SMTP发送邮件详细教程
Jun 19 Python
Python 数值区间处理_对interval 库的快速入门详解
Nov 16 Python
Python自动化运维之Ansible定义主机与组规则操作详解
Jun 13 Python
pandas DataFrame 交集并集补集的实现
Jun 24 Python
sklearn和keras的数据切分与交叉验证的实例详解
Jun 19 Python
python跨文件使用全局变量的实现
Nov 17 Python
python如何用matplotlib创建三维图表
Jan 26 Python
Python页面加载的等待方式总结
Feb 28 Python
Python+uiautomator2实现自动刷抖音视频功能
Apr 29 Python
python利用while求100内的整数和方式
Nov 07 Python
Pycharm中配置远程Docker运行环境的教程图解
Jun 11 #Python
Keras 快速解决OOM超内存的问题
Jun 11 #Python
python3.6.8 + pycharm + PyQt5 环境搭建的图文教程
Jun 11 #Python
使用keras实现孪生网络中的权值共享教程
Jun 11 #Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
You might like
php 模拟 asp.net webFrom 按钮提交事件的思路及代码
2013/12/02 PHP
php GUID生成函数和类
2014/03/10 PHP
PHP中读取文件的8种方法和代码实例
2014/08/05 PHP
解决php extension 加载顺序问题
2019/08/16 PHP
laravel model模型定义实现开启自动管理时间created_at,updated_at
2019/10/17 PHP
table对象中的insertRow与deleteRow使用示例
2014/01/26 Javascript
jQuery中$.ajax()和$.getJson()同步处理详解
2015/08/12 Javascript
jQuery源码分析之init的详细介绍
2017/02/13 Javascript
jquery中each循环的简单回滚操作
2017/05/05 jQuery
ES6学习教程之Map的常用方法总结
2017/08/03 Javascript
代码详解javascript模块加载器
2018/03/04 Javascript
Angular CLI在Angular项目中如何使用scss详解
2018/04/10 Javascript
vue.js做一个简单的编辑菜谱功能
2018/05/08 Javascript
Js经典案例的实例代码
2018/05/10 Javascript
element-ui的回调函数Events的用法详解
2018/10/16 Javascript
使用VUE+iView+.Net Core上传图片的方法示例
2019/01/04 Javascript
你不可不知的Vue.js列表渲染详解
2019/10/01 Javascript
Vue3.0中的monorepo管理模式的实现
2019/10/14 Javascript
python 示例分享---逻辑推理编程解决八皇后
2014/07/20 Python
Python装饰器使用实例:验证参数合法性
2015/06/24 Python
numpy中实现二维数组按照某列、某行排序的方法
2018/04/04 Python
Python+OpenCV感兴趣区域ROI提取方法
2019/01/10 Python
numpy.meshgrid()理解(小结)
2019/08/01 Python
python将字符串转变成dict格式的实现
2019/11/18 Python
numpy实现神经网络反向传播算法的步骤
2019/12/24 Python
python实现logistic分类算法代码
2020/02/28 Python
Python 调用 ES、Solr、Phoenix的示例代码
2020/11/23 Python
PyCharm Community安装与配置的详细教程
2020/11/24 Python
Selenium Webdriver元素定位的八种常用方式(小结)
2021/01/13 Python
Smashbox英国官网:美国知名彩妆品牌
2017/11/13 全球购物
旧时光糖果:Old Time Candy
2018/02/05 全球购物
放假通知怎么写
2015/08/18 职场文书
宝宝满月宴答谢词
2015/09/30 职场文书
2016年春季开学典礼新闻稿
2015/11/25 职场文书
Pillow图像处理库安装及使用
2022/04/12 Python
JS前端轻量fabric.js系列物体基类
2022/08/05 Javascript