Keras实现DenseNet结构操作


Posted in Python onJuly 06, 2020

DenseNet结构在16年由Huang Gao和Liu Zhuang等人提出,并且在CVRP2017中被评为最佳论文。网络的核心结构为如下所示的Dense块,在每一个Dense块中,存在多个Dense层,即下图所示的H1~H4。各Dense层之间彼此均相互连接,即H1的输入为x0,输出为x1,H2的输入即为[x0, x1],输出为x2,依次类推。最终Dense块的输出即为[x0, x1, x2, x3, x4]。这种结构个人感觉非常类似生物学里边的神经元连接方式,应该能够比较有效的提高了网络中特征信息的利用效率。

Keras实现DenseNet结构操作

DenseNet的其他结构就非常类似一般的卷积神经网络结构了,可以参考论文中提供的网路结构图(下图)。但是个人感觉,DenseNet的这种结构应该是存在进一步的优化方法的,比如可能不一定需要在Dense块中对每一个Dense层均直接进行相互连接,来缩小网络的结构;也可能可以在不相邻的Dense块之间通过简单的下采样操作进行连接,进一步提升网络对不同尺度的特征的利用效率。

Keras实现DenseNet结构操作

由于DenseNet的密集连接方式,在构建一个相同容量的网络时其所需的参数数量远小于其之前提出的如resnet等结构。进一步,个人感觉应该可以把Dense块看做对一个有较多参数的卷积层的高效替代。因此,其也可以结合U-Net等网络结构,来进一步优化网络性能,比如单纯的把U-net中的所有卷积层全部换成DenseNet的结构,就可以显著压缩网络大小。

下面基于Keras实现DenseNet-BC结构。首先定义Dense层,根据论文描述构建如下:

def DenseLayer(x, nb_filter, bn_size=4, alpha=0.0, drop_rate=0.2):
 
 # Bottleneck layers
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(bn_size*nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 
 # Composite function
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (3, 3), strides=(1,1), padding='same')(x)
 
 if drop_rate: x = Dropout(drop_rate)(x)
 
 return x

论文原文中提出使用1*1卷积核的卷积层作为bottleneck层来优化计算效率。原文中使用的激活函数全部为relu,但个人习惯是用leakyrelu进行构建,来方便调参。

之后是用Dense层搭建Dense块,如下:

def DenseBlock(x, nb_layers, growth_rate, drop_rate=0.2):
 
 for ii in range(nb_layers):
  conv = DenseLayer(x, nb_filter=growth_rate, drop_rate=drop_rate)
  x = concatenate([x, conv], axis=3)
 return x

如论文中所述,将每一个Dense层的输出与其输入融合之后作为下一Dense层的输入,来实现密集连接。

最后是各Dense块之间的过渡层,如下:

def TransitionLayer(x, compression=0.5, alpha=0.0, is_max=0):
 
 nb_filter = int(x.shape.as_list()[-1]*compression)
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 if is_max != 0: x = MaxPooling2D(pool_size=(2, 2), strides=2)(x)
 else: x = AveragePooling2D(pool_size=(2, 2), strides=2)(x)
 
 return x

论文中提出使用均值池化层来作下采样,不过在边缘特征提取方面,最大池化层效果应该更好,这里就加了相关接口。

将上述结构按照论文中提出的结构进行拼接,这里选择的参数是论文中提到的L=100,k=12,网络连接如下:

growth_rate = 12
inpt = Input(shape=(32,32,3))
 
x = Conv2D(growth_rate*2, (3, 3), strides=1, padding='same')(inpt)
x = BatchNormalization(axis=3)(x)
x = LeakyReLU(alpha=0.1)(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = BatchNormalization(axis=3)(x)
x = GlobalAveragePooling2D()(x)
x = Dense(10, activation='softmax')(x)
 
model = Model(inpt, x)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

虽然我们已经完成了网络的架设,网络本身的参数数量也仅有0.5M,但由于以这种方式实现的网络在Dense块中,每一次concat均需要开辟一组全新的内存空间,导致实际需要的内存空间非常大。作者在17年的时候,还专门写了相关的技术报告:https://arxiv.org/abs/1707.06990来说明怎么节省内存空间,不过单纯用keras实现起来是比较麻烦。下一篇博客中将以pytorch框架来对其进行实现。

最后放出网络完整代码:

import numpy as np
import keras
from keras.models import Model, save_model, load_model
from keras.layers import Input, Dense, Dropout, BatchNormalization, LeakyReLU, concatenate
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D
 
## data
import pickle
 
data_batch_1 = pickle.load(open("cifar-10-batches-py/data_batch_1", 'rb'), encoding='bytes')
data_batch_2 = pickle.load(open("cifar-10-batches-py/data_batch_2", 'rb'), encoding='bytes')
data_batch_3 = pickle.load(open("cifar-10-batches-py/data_batch_3", 'rb'), encoding='bytes')
data_batch_4 = pickle.load(open("cifar-10-batches-py/data_batch_4", 'rb'), encoding='bytes')
data_batch_5 = pickle.load(open("cifar-10-batches-py/data_batch_5", 'rb'), encoding='bytes')
 
train_X_1 = data_batch_1[b'data']
train_X_1 = train_X_1.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_1 = data_batch_1[b'labels']
 
train_X_2 = data_batch_2[b'data']
train_X_2 = train_X_2.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_2 = data_batch_2[b'labels']
 
train_X_3 = data_batch_3[b'data']
train_X_3 = train_X_3.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_3 = data_batch_3[b'labels']
 
train_X_4 = data_batch_4[b'data']
train_X_4 = train_X_4.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_4 = data_batch_4[b'labels']
 
train_X_5 = data_batch_5[b'data']
train_X_5 = train_X_5.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_5 = data_batch_5[b'labels']
 
train_X = np.row_stack((train_X_1, train_X_2))
train_X = np.row_stack((train_X, train_X_3))
train_X = np.row_stack((train_X, train_X_4))
train_X = np.row_stack((train_X, train_X_5))
 
train_Y = np.row_stack((train_Y_1, train_Y_2))
train_Y = np.row_stack((train_Y, train_Y_3))
train_Y = np.row_stack((train_Y, train_Y_4))
train_Y = np.row_stack((train_Y, train_Y_5))
train_Y = train_Y.reshape(50000, 1).transpose(0, 1).astype("int32")
train_Y = keras.utils.to_categorical(train_Y)
 
test_batch = pickle.load(open("cifar-10-batches-py/test_batch", 'rb'), encoding='bytes')
test_X = test_batch[b'data']
test_X = test_X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
test_Y = test_batch[b'labels']
test_Y = keras.utils.to_categorical(test_Y)
 
train_X /= 255
test_X /= 255
 
# model
 
def DenseLayer(x, nb_filter, bn_size=4, alpha=0.0, drop_rate=0.2):
 
 # Bottleneck layers
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(bn_size*nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 
 # Composite function
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (3, 3), strides=(1,1), padding='same')(x)
 
 if drop_rate: x = Dropout(drop_rate)(x)
 
 return x
 
def DenseBlock(x, nb_layers, growth_rate, drop_rate=0.2):
 
 for ii in range(nb_layers):
  conv = DenseLayer(x, nb_filter=growth_rate, drop_rate=drop_rate)
  x = concatenate([x, conv], axis=3)
  
 return x
 
def TransitionLayer(x, compression=0.5, alpha=0.0, is_max=0):
 
 nb_filter = int(x.shape.as_list()[-1]*compression)
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 if is_max != 0: x = MaxPooling2D(pool_size=(2, 2), strides=2)(x)
 else: x = AveragePooling2D(pool_size=(2, 2), strides=2)(x)
 
 return x
 
growth_rate = 12
 
inpt = Input(shape=(32,32,3))
 
x = Conv2D(growth_rate*2, (3, 3), strides=1, padding='same')(inpt)
x = BatchNormalization(axis=3)(x)
x = LeakyReLU(alpha=0.1)(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = BatchNormalization(axis=3)(x)
x = GlobalAveragePooling2D()(x)
x = Dense(10, activation='softmax')(x)
 
model = Model(inpt, x)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
 
model.summary()
 
for ii in range(10):
 print("Epoch:", ii+1)
 model.fit(train_X, train_Y, batch_size=100, epochs=1, verbose=1)
 score = model.evaluate(test_X, test_Y, verbose=1)
 print('Test loss =', score[0])
 print('Test accuracy =', score[1])
 
save_model(model, 'DenseNet.h5')
model = load_model('DenseNet.h5')
 
pred_Y = model.predict(test_X)
score = model.evaluate(test_X, test_Y, verbose=0)
print('Test loss =', score[0])
print('Test accuracy =', score[1])

以上这篇Keras实现DenseNet结构操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python list使用示例 list中找连续的数字
Jan 27 Python
python开发之thread实现布朗运动的方法
Nov 11 Python
python使用arcpy.mapping模块批量出图
Mar 06 Python
python 中的int()函数怎么用
Oct 17 Python
Django2.1.3 中间件使用详解
Nov 26 Python
Python线程池模块ThreadPoolExecutor用法分析
Dec 28 Python
python 动态迁移solr数据过程解析
Sep 04 Python
Python测试Kafka集群(pykafka)实例
Dec 23 Python
快速解释如何使用pandas的inplace参数的使用
Jul 23 Python
Python控制鼠标键盘代码实例
Dec 08 Python
Python Selenium异常处理的实例分析
Feb 28 Python
python和anaconda的区别
May 06 Python
基于Python和C++实现删除链表的节点
Jul 06 #Python
基于Python 的语音重采样函数解析
Jul 06 #Python
python interpolate插值实例
Jul 06 #Python
基于Python实现2种反转链表方法代码实例
Jul 06 #Python
简单了解Django项目应用创建过程
Jul 06 #Python
如何在mac下配置python虚拟环境
Jul 06 #Python
Python优秀开源项目Rich源码解析的流程分析
Jul 06 #Python
You might like
二十行语句实现从Excel到mysql的转化
2006/10/09 PHP
简单的移动设备检测PHP脚本代码
2011/02/19 PHP
php中取得文件的后缀名?
2012/02/20 PHP
PHP mysqli事务操作常用方法分析
2017/07/22 PHP
js模拟实现Array的sort方法
2007/12/11 Javascript
jquery 得到当前页面高度和宽度的两个函数
2010/02/21 Javascript
基于Jquery的标签智能验证实现代码
2010/12/27 Javascript
有关JavaScript的10个怪癖和秘密分享
2011/08/28 Javascript
Javascript new Date().valueOf()的作用与时间戳由来详解
2013/04/24 Javascript
Javascript中Event属性搜集整理
2013/09/17 Javascript
JS简单实现元素复制示例附图
2013/11/19 Javascript
JavaScript事件委托的技术原理探讨示例
2014/04/17 Javascript
JS实现颜色梯度与渐变效果完整实例
2016/12/30 Javascript
Vue ElementUi同时校验多个表单(巧用new promise)
2018/06/06 Javascript
移动端图片上传旋转、压缩问题的方法
2018/10/16 Javascript
详细讲解如何创建, 发布自己的 Vue UI 组件库
2019/05/29 Javascript
详解js中的原型,原型对象,原型链
2020/07/16 Javascript
JavaScript实现轮播图效果
2020/10/30 Javascript
关于angular 8.1使用过程中的一些记录
2020/11/25 Javascript
[53:52]EG vs VGJ.T 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python设置windows桌面壁纸的实现代码
2013/01/28 Python
python重试装饰器示例
2014/02/11 Python
详解从Django Rest Framework响应中删除空字段
2019/01/11 Python
Python Pandas 对列/行进行选择,增加,删除操作
2020/05/17 Python
Python中random模块常用方法的使用教程
2020/10/04 Python
python Gabor滤波器讲解
2020/10/26 Python
python 爬虫如何正确的使用cookie
2020/10/27 Python
Expedia丹麦:全球领先的旅游网站
2018/03/18 全球购物
时尚设计师手表:The Watch Cabin
2018/10/06 全球购物
美国二手复古奢侈品包包购物网站:LXRandCo
2019/06/18 全球购物
会计毕业生自荐信
2013/11/21 职场文书
应届生求职信范文
2014/06/30 职场文书
实习单位证明范例
2014/11/17 职场文书
搭讪开场白台词大全
2015/05/28 职场文书
OpenCV-Python使用cv2实现傅里叶变换
2021/06/09 Python
Python内置数据结构列表与元组示例详解
2021/08/04 Python