python神经网络Xception模型


Posted in Python onMay 06, 2022

Xception是继Inception后提出的对Inception v3的另一种改进,学一学总是好的

什么是Xception模型

Xception是谷歌公司继Inception后,提出的InceptionV3的一种改进模型,其改进的主要内容为采用depthwise separable convolution来替换原来Inception v3中的多尺寸卷积核特征响应操作。

在讲Xception模型之前,首先要讲一下什么是depthwise separable convolution(深度可分离卷积块)。

深度可分离卷积块由两个部分组成,分别是深度可分离卷积和1x1普通卷积,深度可分离卷积的卷积核大小一般是3x3的,便于理解的话我们可以把它当作是特征提取,1x1的普通卷积可以完成通道数的调整。

下图为深度可分离卷积块的结构示意图:

python神经网络Xception模型


深度可分离卷积块的目的是使用更少的参数来代替普通的3x3卷积。

我们可以进行一下普通卷积和深度可分离卷积块的对比:

假设有一个3×3大小的卷积层,其输入通道为16、输出通道为32。具体为,32个3×3大小的卷积核会遍历16个通道中的每个数据,最后可得到所需的32个输出通道,所需参数为16×32×3×3=4608个。

应用深度可分离卷积,用16个3×3大小的卷积核分别遍历16通道的数据,得到了16个特征图谱。在融合操作之前,接着用32个1×1大小的卷积核遍历这16个特征图谱,所需参数为16×3×3+16×32×1×1=656个。

可以看出来depthwise separable convolution可以减少模型的参数。

通俗地理解深度可分离卷积结构块,就是3x3的卷积核厚度只有一层,然后在输入张量上一层一层地滑动,每一次卷积完生成一个输出通道,当卷积完成后,再利用1x1的卷积调整厚度。

(视频中有些许错误,感谢zl960929的提醒,Xception使用的深度可分离卷积块SeparableConv2D也就是先深度可分离卷积再进行1x1卷积。)

对于Xception模型而言,其一共可以分为3个flow,分别是Entry flow、Middle flow、Exit flow;分为14个block,其中Entry flow中有4个、Middle flow中有8个、Exit flow中有2个。具体结构如下:

python神经网络Xception模型


其内部主要结构就是残差卷积网络搭配SeparableConv2D层实现一个个block,在Xception模型中,常见的两个block的结构如下。
这个主要在Entry flow和Exit flow中:

python神经网络Xception模型


这个主要在Middle flow中:

python神经网络Xception模型

Xception网络部分实现代码

#-------------------------------------------------------------#
#   Xception的网络部分
#-------------------------------------------------------------#
from keras.preprocessing import image

from keras.models import Model
from keras import layers
from keras.layers import Dense,Input,BatchNormalization,Activation,Conv2D,SeparableConv2D,MaxPooling2D
from keras.layers import GlobalAveragePooling2D,GlobalMaxPooling2D
from keras import backend as K
from keras.applications.imagenet_utils import decode_predictions


def Xception(input_shape = [299,299,3],classes=1000):


    img_input = Input(shape=input_shape)

    #--------------------------#
    # Entry flow
    #--------------------------#
    #--------------------#
    # block1
    #--------------------#
    # 299,299,3 -> 149,149,64
    x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input)
    x = BatchNormalization(name='block1_conv1_bn')(x)
    x = Activation('relu', name='block1_conv1_act')(x)
    x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
    x = BatchNormalization(name='block1_conv2_bn')(x)
    x = Activation('relu', name='block1_conv2_act')(x)

    #--------------------#
    # block2
    #--------------------#
    # 149,149,64 -> 75,75,128
    residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
    x = BatchNormalization(name='block2_sepconv1_bn')(x)
    x = Activation('relu', name='block2_sepconv2_act')(x)
    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
    x = BatchNormalization(name='block2_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
    x = layers.add([x, residual])

    #--------------------#
    # block3
    #--------------------#
    # 75,75,128 -> 38,38,256
    residual = Conv2D(256, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block3_sepconv1_act')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
    x = BatchNormalization(name='block3_sepconv1_bn')(x)
    x = Activation('relu', name='block3_sepconv2_act')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
    x = BatchNormalization(name='block3_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
    x = layers.add([x, residual])

    #--------------------#
    # block4
    #--------------------#
    # 38,38,256 -> 19,19,728
    residual = Conv2D(728, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block4_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
    x = BatchNormalization(name='block4_sepconv1_bn')(x)
    x = Activation('relu', name='block4_sepconv2_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
    x = BatchNormalization(name='block4_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
    x = layers.add([x, residual])

    #--------------------------#
    # Middle flow
    #--------------------------#
    #--------------------#
    # block5--block12
    #--------------------#
    # 19,19,728 -> 19,19,728
    for i in range(8):
        residual = x
        prefix = 'block' + str(i + 5)

        x = Activation('relu', name=prefix + '_sepconv1_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
        x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv2_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
        x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
        x = Activation('relu', name=prefix + '_sepconv3_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
        x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)

        x = layers.add([x, residual])

    #--------------------------#
    # Exit flow
    #--------------------------#
    #--------------------#
    # block13
    #--------------------#
    # 19,19,728 -> 10,10,1024
    residual = Conv2D(1024, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu', name='block13_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
    x = BatchNormalization(name='block13_sepconv1_bn')(x)
    x = Activation('relu', name='block13_sepconv2_act')(x)
    x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
    x = BatchNormalization(name='block13_sepconv2_bn')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
    x = layers.add([x, residual])

    #--------------------#
    # block14
    #--------------------#
    # 10,10,1024 -> 10,10,2048
    x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
    x = BatchNormalization(name='block14_sepconv1_bn')(x)
    x = Activation('relu', name='block14_sepconv1_act')(x)

    x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
    x = BatchNormalization(name='block14_sepconv2_bn')(x)
    x = Activation('relu', name='block14_sepconv2_act')(x)

    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    inputs = img_input

    model = Model(inputs, x, name='xception')

    model.load_weights("xception_weights_tf_dim_ordering_tf_kernels.h5")

    return model

图片预测

建立网络后,可以用以下的代码进行预测。

def preprocess_input(x):
    x /= 255.
    x -= 0.5
    x *= 2.
    return x


if __name__ == '__main__':
    model = Xception()

    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(299, 299))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    print('Input image shape:', x.shape)

    preds = model.predict(x)
    print(np.argmax(preds))
    print('Predicted:', decode_predictions(preds))

预测所需的已经训练好的Xception模型可以在https://github.com/fchollet/deep-learning-models/releases下载。非常方便。

预测结果为:

Predicted: [[('n02504458', 'African_elephant', 0.47570863), ('n01871265', 'tusker', 0.3173351), ('n02504013', 'Indian_elephant', 0.030323735), ('n02963159', 'cardigan', 0.0007877756), ('n02410509', 'bison', 0.00075616257)]]

以上就是python神经网络Xception模型详解的详细内容,更多关于Xception模型的复现详解的资料请关注三水点靠木其它相关文章!


Tags in this post...

Python 相关文章推荐
Python线程的两种编程方式
Apr 14 Python
在Python下尝试多线程编程
Apr 28 Python
利用Python的Django框架生成PDF文件的教程
Jul 22 Python
Python中类型检查的详细介绍
Feb 13 Python
python互斥锁、加锁、同步机制、异步通信知识总结
Feb 11 Python
利用Django-environ如何区分不同环境
Aug 26 Python
python分批定量读取文件内容,输出到不同文件中的方法
Dec 08 Python
python多线程分块读取文件
Aug 29 Python
Python bytes string相互转换过程解析
Mar 05 Python
pycharm中leetcode插件使用图文详解
Dec 07 Python
java关于string最常出现的面试题整理
Jan 18 Python
Python基础数据类型tuple元组的概念与用法
Aug 02 Python
Python使用永中文档转换服务
May 06 #Python
Python tensorflow卷积神经Inception V3网络结构
May 06 #Python
Python实现Matplotlib,Seaborn动态数据图
May 06 #Python
PYTHON InceptionV3模型的复现详解
代码复现python目标检测yolo3详解预测
讲解Python实例练习逆序输出字符串
May 06 #Python
python turtle绘图
May 04 #Python
You might like
当海贼王变成JOJO风
2020/03/02 日漫
动漫女神老婆无限好,但日本女生可能就不是这么一回事了!
2020/03/04 日漫
用mysql触发器自动更新memcache的实现代码
2009/10/11 PHP
非常好用的Zend Framework分页类
2014/06/25 PHP
ThinkPHP框架实现导出excel数据的方法示例【基于PHPExcel】
2018/05/12 PHP
基于jquery的一个简单的脚本验证插件
2010/04/05 Javascript
JavaScript中的面向对象介绍
2012/06/30 Javascript
jQuery语法总结和注意事项小结
2012/11/11 Javascript
jQuery实现id模糊查询的小例子
2013/03/19 Javascript
JQUERY对单选框(radio)操作的小例子
2013/04/25 Javascript
javascript如何动态加载表格与动态添加表格行
2013/11/27 Javascript
jquery结合CSS使用validate实现漂亮的验证
2015/01/29 Javascript
详解JavaScript的while循环的使用
2015/06/03 Javascript
jQuery的框架介绍
2016/05/11 Javascript
JS实现隐藏同级元素后只显示JS文件内容的方法
2016/09/04 Javascript
NPM 安装cordova时警告:npm WARN deprecated minimatch@2.0.10: Please update to minimatch 3.0.2 or higher to
2016/12/20 Javascript
深入理解vue2.0路由如何配置问题
2017/07/18 Javascript
BootStrap下的弹出框加载select2框架失败的解决方法
2017/08/31 Javascript
Angular.js中window.onload(),$(document).ready()的写法浅析
2017/09/28 Javascript
写给小白看的JavaScript异步
2017/11/29 Javascript
NodeJS实现不可逆加密与密码密文保存的方法
2018/03/16 NodeJs
实例讲解v-if和v-show的区别
2019/01/31 Javascript
详解JavaScript 新语法之Class 的私有属性与私有方法
2019/04/23 Javascript
Vue 中可以定义组件模版的几种方式
2019/08/06 Javascript
深入理解python try异常处理机制
2016/06/01 Python
Linux下python3.7.0安装教程
2018/07/30 Python
python 单线程和异步协程工作方式解析
2019/09/28 Python
Jupyter Notebook输出矢量图实例
2020/04/14 Python
基于python实现计算两组数据P值
2020/07/10 Python
详解CSS3:overflow属性
2020/11/17 HTML / CSS
美丽家庭事迹材料
2014/05/03 职场文书
小学兴趣小组活动总结
2014/07/07 职场文书
合作协议书模板
2014/10/10 职场文书
校本培训个人总结
2015/02/28 职场文书
投诉书范文
2015/07/02 职场文书
mysql sock文件存储了什么信息
2022/07/15 MySQL