深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
跟老齐学Python之不要红头文件(2)
Sep 28 Python
Python中列表的一些基本操作知识汇总
May 20 Python
python3音乐播放器简单实现代码
Apr 20 Python
快速入门python学习笔记
Dec 06 Python
用python实现的线程池实例代码
Jan 06 Python
Python爬虫实例_利用百度地图API批量获取城市所有的POI点
Jan 10 Python
浅析PyTorch中nn.Linear的使用
Aug 18 Python
使用python的turtle绘画滑稽脸实例
Nov 21 Python
如何使用python代码操作git代码
Feb 29 Python
Python unittest单元测试openpyxl实现过程解析
May 27 Python
Python如何实现FTP功能
May 28 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
PHP set_time_limit(0)长连接的实现分析
2010/03/02 PHP
PHP实例分享判断客户端是否使用代理服务器及其匿名级别
2014/06/04 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(八)
2014/06/23 PHP
PHP的Yii框架的基本使用示例
2015/08/21 PHP
PHP表单提交后引号前自动加反斜杠的原因及三种办法关闭php魔术引号
2015/09/30 PHP
php中上传文件的的解决方案
2018/09/25 PHP
Alliance vs Liquid BO3 第二场2.13
2021/03/10 DOTA
Nigma vs Alliance BO5 第四场2.14
2021/03/10 DOTA
基于mootools插件实现遮罩层新手引导
2012/05/24 Javascript
文本框中禁止非数字字符输入比如手机号码、邮编
2013/08/19 Javascript
实现网页页面跳转的几种方法(meta标签、js实现、php实现)
2014/05/20 Javascript
谷歌地图打不开的解决办法
2014/08/07 Javascript
快速学习jQuery插件 Form表单插件使用方法
2015/12/01 Javascript
js跨浏览器的事件侦听器和事件对象的使用方法
2015/12/17 Javascript
Bootstrap入门书籍之(三)栅格系统
2016/02/17 Javascript
jquery把int类型转换成字符串类型的方法
2016/10/07 Javascript
AngularJS全局scope与Isolate scope通信用法示例
2016/11/22 Javascript
jquery ajax异步提交表单数据的方法
2017/10/27 jQuery
node实现登录图片验证码的示例代码
2018/04/20 Javascript
webstorm+vue初始化项目的方法
2018/10/18 Javascript
react结合bootstrap实现评论功能
2020/05/30 Javascript
jQuery带控制按钮轮播图插件
2020/07/31 jQuery
python 简易计算器程序,代码就几行
2009/08/29 Python
Python 字符串与二进制串的相互转换示例
2018/07/23 Python
Django页面数据的缓存与使用的具体方法
2019/04/23 Python
python3+django2开发一个简单的人员管理系统过程详解
2019/07/23 Python
利用Python实现kNN算法的代码
2019/08/16 Python
浅谈python之自动化运维(Paramiko)
2020/01/31 Python
详解python对象之间的交互
2020/09/29 Python
python 通过pip freeze、dowload打离线包及自动安装的过程详解(适用于保密的离线环境
2020/12/14 Python
canvas简易绘图的实现(海绵宝宝篇)
2018/07/04 HTML / CSS
香港万宁官方海外旗舰店:香港健与美连锁店
2018/09/27 全球购物
开会迟到检讨书
2014/01/08 职场文书
解除劳动合同协议书
2014/04/14 职场文书
财务稽核岗位职责
2015/04/13 职场文书
经典格言警句:没有热忱,世间便无进步
2019/11/13 职场文书