深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python性能提升之延迟初始化
Dec 04 Python
Python实现mysql数据库更新表数据接口的功能
Nov 19 Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 Python
对python的unittest架构公共参数token提取方法详解
Dec 17 Python
Python内存管理实例分析
Jul 10 Python
python实现一行输入多个值和一行输出多个值的例子
Jul 16 Python
Python如何使用BeautifulSoup爬取网页信息
Nov 26 Python
浅谈pymysql查询语句中带有in时传递参数的问题
Jun 05 Python
celery在python爬虫中定时操作实例讲解
Nov 27 Python
python 视频下载神器(you-get)的具体使用
Jan 06 Python
pandas apply使用多列计算生成新的列实现示例
Feb 24 Python
Python还能这么玩之只用30行代码从excel提取个人值班表
Jun 05 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
php侧拉菜单 漂亮,可以向右或者向左展开,支持FF,IE
2009/10/15 PHP
PHP判断文件是否存在、是否可读、目录是否存在的代码
2012/10/03 PHP
PHP中上传多个文件的表单设计例子
2014/11/19 PHP
thinkphp文件引用与分支结构用法实例
2014/11/26 PHP
PHP获取客户端及服务器端IP的封装类
2016/07/21 PHP
PHP 实现页面静态化的几种方法
2017/07/23 PHP
thinkphp集成前端脚手架Vue-cli的教程图解
2018/08/30 PHP
jquery 简单应用示例总结
2013/08/09 Javascript
JavaScript弹出新窗口并控制窗口移动到指定位置的方法
2015/04/06 Javascript
JavaScript笔记之数据属性和存储器属性
2016/03/31 Javascript
JS生成不重复的随机数组的简单实例
2016/07/10 Javascript
基于BootStrap实现局部刷新分页实例代码
2016/08/08 Javascript
Bootstrap CSS使用方法
2016/12/23 Javascript
纯js实现倒计时功能
2017/01/06 Javascript
ui-router中使用ocLazyLoad和resolve的具体方法
2017/10/18 Javascript
微信小程序返回多级页面的实现方法
2017/10/27 Javascript
详解vue-cli 3.0 build包太大导致首屏过长的解决方案
2018/11/10 Javascript
vue-cli3单页构建大型项目方案
2020/04/07 Javascript
JS如何实现在弹出窗口中加载页面
2020/12/03 Javascript
[01:21:36]CHAOS vs Alliacne 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
python利用paramiko连接远程服务器执行命令的方法
2017/10/16 Python
详细解读tornado协程(coroutine)原理
2018/01/15 Python
Python使用matplotlib填充图形指定区域代码示例
2018/01/16 Python
Python 字符串转换为整形和浮点类型的方法
2018/07/17 Python
Python切片操作深入详解
2018/07/27 Python
python爬取百度贴吧前1000页内容(requests库面向对象思想实现)
2019/08/10 Python
简单的Python人脸识别系统
2020/07/14 Python
python实现canny边缘检测
2020/09/14 Python
python 读取yaml文件的两种方法(在unittest中使用)
2020/12/01 Python
HTML中使用SVG与SVG预定义形状元素介绍
2013/06/28 HTML / CSS
科颜氏加拿大官方网站: Kiehl’s加拿大
2016/08/16 全球购物
日本最大美瞳直送网:Morecontact(中文)
2019/04/03 全球购物
毕业生求职简历的自我评价
2013/10/07 职场文书
幼儿园教学管理制度
2014/02/04 职场文书
优秀毕业生自我鉴定
2014/02/11 职场文书
Python面向对象之内置函数相关知识总结
2021/06/24 Python