python神经网络ResNet50模型


Posted in Python onMay 06, 2022

什么是残差网络

最近看yolo3里面讲到了残差网络,对这个网络结构很感兴趣,于是了解到这个网络结构最初的使用是在ResNet网络里。

Residual net(残差网络):
 

将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。

意味着后面的特征层的内容会有一部分由其前面的某一层线性贡献。

其结构如下:

python神经网络ResNet50模型


深度残差网络的设计是为了克服由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题。

什么是ResNet50模型

ResNet50有两个基本的块,分别名为Conv Block和Identity Block,其中Conv Block输入和输出的维度是不一样的,所以不能连续串联,它的作用是改变网络的维度;

Identity Block输入维度和输出维度相同,可以串联,用于加深网络的。

Conv Block的结构如下:

python神经网络ResNet50模型


Identity Block的结构如下:

python神经网络ResNet50模型


这两个都是残差网络结构。

总的网络结构如下:

python神经网络ResNet50模型


这样看起来可能比较抽象,还有一副很好的我从网上找的图,可以拉到最后面去看哈,放前面太占位置了。

ResNet50网络部分实现代码

#-------------------------------------------------------------#
#   ResNet50的网络部分
#-------------------------------------------------------------#
from __future__ import print_function

import numpy as np
from keras import layers

from keras.layers import Input
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.layers import Activation,BatchNormalization,Flatten
from keras.models import Model

from keras.preprocessing import image
import keras.backend as K
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import decode_predictions
from keras.applications.imagenet_utils import preprocess_input


def identity_block(input_tensor, kernel_size, filters, stage, block):

    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size,padding='same', name=conv_name_base + '2b')(x)

    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)

    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x


def ResNet50(input_shape=[224,224,3],classes=1000):

    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')

    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")

    return model

图片预测

建立网络后,可以用以下的代码进行预测。

if __name__ == '__main__':
    model = ResNet50()
    model.summary()
    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    print('Input image shape:', x.shape)
    preds = model.predict(x)
    print('Predicted:', decode_predictions(preds))

预测所需的已经训练好的ResNet50模型可以在https://github.com/fchollet/deep-learning-models/releases下载。非常方便。
预测结果为:

Predicted: [[('n01871265', 'tusker', 0.41107917), ('n02504458', 'African_elephant', 0.39015812), ('n02504013', 'Indian_elephant', 0.12260196), ('n03000247', 'chain_mail', 0.023176488), ('n02437312', 'Arabian_camel', 0.020982226)]]

以上就是python神经网络ResNet50模型的复现详解的详细内容!


Tags in this post...

Python 相关文章推荐
python连接MySQL数据库实例分析
May 12 Python
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
Jan 20 Python
Python基础教程之内置函数locals()和globals()用法分析
Mar 16 Python
python数字图像处理之高级形态学处理
Apr 27 Python
Selenium定位元素操作示例
Aug 10 Python
Python的argparse库使用详解
Oct 09 Python
Python 实现两个服务器之间文件的上传方法
Feb 13 Python
简单易懂Pytorch实战实例VGG深度网络
Aug 27 Python
Python: tkinter窗口屏幕居中,设置窗口最大,最小尺寸实例
Mar 04 Python
python 连续不等式语法糖实例
Apr 15 Python
如何用python处理excel表格
Jun 09 Python
python可视化 matplotlib画图使用colorbar工具自定义颜色
Dec 07 Python
python和anaconda的区别
May 06 #Python
python神经网络Xception模型
May 06 #Python
Python使用永中文档转换服务
May 06 #Python
Python tensorflow卷积神经Inception V3网络结构
May 06 #Python
Python实现Matplotlib,Seaborn动态数据图
May 06 #Python
PYTHON InceptionV3模型的复现详解
代码复现python目标检测yolo3详解预测
You might like
php强制下载文件函数
2016/08/24 PHP
删除PHP数组中的重复元素的实现代码
2017/04/10 PHP
ASP.NET中使用后端代码注册脚本 生成JQUERY-EASYUI的界面错位的解决方法
2010/06/12 Javascript
菜鸟学习JavaScript小实验之函数引用
2010/11/17 Javascript
浅析document.createDocumentFragment()与js效率
2013/07/08 Javascript
nodejs教程之异步I/O
2014/11/21 NodeJs
jQuery中dom元素上绑定的事件详解
2015/04/24 Javascript
Javascript 高阶函数使用介绍
2015/06/15 Javascript
快速学习AngularJs HTTP响应拦截器
2015/12/31 Javascript
瀑布流的实现方式(原生js+jquery+css3)
2020/06/28 Javascript
避免jQuery名字冲突 noConflict()方法
2016/07/30 Javascript
解决OneThink中无法异步提交kindeditor文本框中修改后的内容方法
2017/05/05 Javascript
Bootstrap 3多级下拉菜单实例
2017/11/23 Javascript
mpvue 如何使用腾讯视频插件的方法
2018/07/16 Javascript
详解element-ui设置下拉选择切换必填和非必填
2019/06/17 Javascript
node 解析图片二维码的内容代码实例
2019/09/11 Javascript
判断JavaScript中的两个变量是否相等的操作符
2019/12/21 Javascript
解决python中os.listdir()函数读取文件夹下文件的乱序和排序问题
2018/10/17 Python
python中web框架的自定义创建
2019/09/08 Python
Windows 平台做 Python 开发的最佳组合(推荐)
2020/07/27 Python
用HTML5实现鼠标滚轮事件放大缩小图片的功能
2015/06/25 HTML / CSS
美国折扣宠物药房:Total Pet Supply
2018/05/27 全球购物
美国领先的个性化礼品商城:Personalization Mall
2019/07/27 全球购物
娱乐地球:Entertainment Earth
2020/01/08 全球购物
WEB控件及HTML服务端控件能否调用客户端方法?如果能,请解释如何调用?
2015/08/25 面试题
专科毕业生就业推荐信
2013/11/01 职场文书
2014端午节活动策划方案
2014/01/27 职场文书
安全检查管理制度
2014/02/02 职场文书
小学校长竞聘演讲稿
2014/05/16 职场文书
婚前财产协议书范本
2014/10/19 职场文书
阿里云服务器搭建Php+Apache运行环境的详细过程
2021/05/15 PHP
CSS 制作波浪效果的思路
2021/05/18 HTML / CSS
Python基础之条件语句详解
2021/06/16 Python
Python接口自动化之文件上传/下载接口详解
2022/04/05 Python
青岛市的收音机研制与生产
2022/04/07 无线电
MySQL存储过程及语法详解
2022/08/05 MySQL