使用keras实现densenet和Xception的模型融合


Posted in Python onMay 23, 2020

我正在参加天池上的一个竞赛,刚开始用的是DenseNet121但是效果没有达到预期,因此开始尝试使用模型融合,将Desenet和Xception融合起来共同提取特征。

代码如下:

def Multimodel(cnn_weights_path=None,all_weights_path=None,class_num=5,cnn_no_vary=False):
	'''
	获取densent121,xinception并联的网络
	此处的cnn_weights_path是个列表是densenet和xception的卷积部分的权值
	'''
	input_layer=Input(shape=(224,224,3))
	dense=DenseNet121(include_top=False,weights=None,input_shape=(224,224,3))
	xception=Xception(include_top=False,weights=None,input_shape=(224,224,3))
	#res=ResNet50(include_top=False,weights=None,input_shape=(224,224,3))

	if cnn_no_vary:
		for i,layer in enumerate(dense.layers):
			dense.layers[i].trainable=False
		for i,layer in enumerate(xception.layers):
			xception.layers[i].trainable=False
		#for i,layer in enumerate(res.layers):
		#	res.layers[i].trainable=False
 
	if cnn_weights_path!=None:
		dense.load_weights(cnn_weights_path[0])
		xception.load_weights(cnn_weights_path[1])
		#res.load_weights(cnn_weights_path[2])
	dense=dense(input_layer)
	xception=xception(input_layer)

	#对dense_121和xception进行全局最大池化
	top1_model=GlobalMaxPooling2D(data_format='channels_last')(dense)
	top2_model=GlobalMaxPooling2D(data_format='channels_last')(xception)
	#top3_model=GlobalMaxPool2D(input_shape=res.output_shape)(res.outputs[0])
	
	print(top1_model.shape,top2_model.shape)
	#把top1_model和top2_model连接起来
	t=keras.layers.Concatenate(axis=1)([top1_model,top2_model])
	#第一个全连接层
	top_model=Dense(units=512,activation="relu")(t)
	top_model=Dropout(rate=0.5)(top_model)
	top_model=Dense(units=class_num,activation="softmax")(top_model)
	
	model=Model(inputs=input_layer,outputs=top_model)
 
	#加载全部的参数
	if all_weights_path:
		model.load_weights(all_weights_path)
	return model

如下进行调用:

if __name__=="__main__":
 weights_path=["./densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5",
 "xception_weights_tf_dim_ordering_tf_kernels_notop.h5"]
 model=Multimodel(cnn_weights_path=weights_path,class_num=6)
 plot_model(model,to_file="G:/model.png")

最后生成的模型图如下:有点长,可以不看

使用keras实现densenet和Xception的模型融合

需要注意的一点是,如果dense=dense(input_layer)这里报错的话,说明你用的是tensorflow1.4以下的版本,解决的方法就是

1、升级tensorflow到1.4以上

2、改代码:

def Multimodel(cnn_weights_path=None,all_weights_path=None,class_num=5,cnn_no_vary=False):
	'''
	获取densent121,xinception并联的网络
	此处的cnn_weights_path是个列表是densenet和xception的卷积部分的权值
	'''
	dir=os.getcwd()
	input_layer=Input(shape=(224,224,3))
	
	dense=DenseNet121(include_top=False,weights=None,input_tensor=input_layer,
		input_shape=(224,224,3))
	xception=Xception(include_top=False,weights=None,input_tensor=input_layer,
		input_shape=(224,224,3))
	#res=ResNet50(include_top=False,weights=None,input_shape=(224,224,3))
 
	if cnn_no_vary:
		for i,layer in enumerate(dense.layers):
			dense.layers[i].trainable=False
		for i,layer in enumerate(xception.layers):
			xception.layers[i].trainable=False
		#for i,layer in enumerate(res.layers):
		#	res.layers[i].trainable=False
	if cnn_weights_path!=None:
		dense.load_weights(cnn_weights_path[0])
		xception.load_weights(cnn_weights_path[1])
 
	#print(dense.shape,xception.shape)
	#对dense_121和xception进行全局最大池化
	top1_model=GlobalMaxPooling2D(input_shape=(7,7,1024),data_format='channels_last')(dense.output)
	top2_model=GlobalMaxPooling2D(input_shape=(7,7,1024),data_format='channels_last')(xception.output)
	#top3_model=GlobalMaxPool2D(input_shape=res.output_shape)(res.outputs[0])
	
	print(top1_model.shape,top2_model.shape)
	#把top1_model和top2_model连接起来
	t=keras.layers.Concatenate(axis=1)([top1_model,top2_model])
	#第一个全连接层
	top_model=Dense(units=512,activation="relu")(t)
	top_model=Dropout(rate=0.5)(top_model)
	top_model=Dense(units=class_num,activation="softmax")(top_model)
	
	model=Model(inputs=input_layer,outputs=top_model)
 
	#加载全部的参数
	if all_weights_path:
		model.load_weights(all_weights_path)
	return model

这个bug我也是在服务器上跑的时候才出现的,找了半天,而实验室的cuda和cudnn又改不了,tensorflow无法升级,因此只能改代码了。

如下所示,是最后画出的模型图:(很长,底下没内容了)

使用keras实现densenet和Xception的模型融合

以上这篇使用keras实现densenet和Xception的模型融合就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python静态方法实例
Jan 14 Python
浅谈python多线程和队列管理shell程序
Aug 04 Python
Python中import导入上一级目录模块及循环import问题的解决
Jun 04 Python
详解 Python 与文件对象共事的实例
Sep 11 Python
python 3.5实现检测路由器流量并写入txt的方法实例
Dec 17 Python
python 读取视频,处理后,实时计算帧数fps的方法
Jul 10 Python
Python基础之循环语句用法示例【for、while循环】
Mar 23 Python
Django 过滤器汇总及自定义过滤器使用详解
Jul 19 Python
pandas将多个dataframe以多个sheet的形式保存到一个excel文件中
Oct 10 Python
python3-flask-3将信息写入日志的实操方法
Nov 12 Python
解决Python图形界面中设置尺寸的问题
Mar 05 Python
Python实现双向链表基本操作
May 25 Python
在keras下实现多个模型的融合方式
May 23 #Python
Keras使用ImageNet上预训练的模型方式
May 23 #Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
You might like
php生成随机密码的三种方法小结
2010/09/04 PHP
PHP与jquery实时显示网站在线人数实例详解
2016/12/02 PHP
如何在标题栏显示框架内页面的标题
2007/02/03 Javascript
Jquery Select操作方法集合脚本之家特别版
2010/05/17 Javascript
ajax的hide隐藏问题解决方法
2012/12/11 Javascript
Js 导出table内容到Excel的简单实例
2013/11/19 Javascript
javascript使用定时函数实现跳转到某个页面
2013/12/25 Javascript
Ubuntu中搭建Nodejs开发环境过程分享
2014/06/01 NodeJs
浅谈setTimeout 与 setInterval
2015/06/23 Javascript
layui table单元格事件修改值的方法
2019/09/24 Javascript
layer.open提交子页面的form和layedit文本编辑内容的方法
2019/09/27 Javascript
vue实现几秒后跳转新页面代码
2020/09/09 Javascript
如何在JavaScript中正确处理变量
2020/12/25 Javascript
[00:12]DAC2018 Miracle-站上中单舞台,他能否再写奇迹?
2018/04/06 DOTA
python翻译软件实现代码(使用google api完成)
2013/11/26 Python
在Django的session中使用User对象的方法
2015/07/23 Python
浅谈python脚本设置运行参数的方法
2018/12/03 Python
python如何获取当前文件夹下所有文件名详解
2019/01/25 Python
对dataframe数据之间求补集的实例详解
2019/01/30 Python
Python使用numpy模块实现矩阵和列表的连接操作方法
2019/06/26 Python
python3 enum模块的应用实例详解
2019/08/12 Python
Python GUI库PyQt5样式QSS子控件介绍
2020/02/25 Python
pycharm中leetcode插件使用图文详解
2020/12/07 Python
详解python3类型注释annotations实用案例
2021/01/20 Python
python脚本使用阿里云slb对恶意攻击进行封堵的实现
2021/02/04 Python
突袭HTML5之Javascript API扩展3—本地存储全新体验
2013/01/31 HTML / CSS
新学期标语
2014/06/30 职场文书
关于旅游的活动方案
2014/08/15 职场文书
高等学院职业生涯规划书范文
2014/09/16 职场文书
上课迟到检讨书300字
2014/10/15 职场文书
2014年销售部工作总结
2014/12/01 职场文书
怎样写工作总结啊!
2019/06/18 职场文书
求职自荐信该如何书写?
2019/06/24 职场文书
MySQL索引篇之千万级数据实战测试
2021/04/05 MySQL
mysql联合索引的使用规则
2021/06/23 MySQL
深入讲解数据库中Decimal类型的使用以及实现方法
2022/02/15 MySQL