深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python中mechanize库的简单使用示例
Jan 10 Python
Python教程之全局变量用法
Jun 27 Python
Python实现破解猜数游戏算法示例
Sep 25 Python
python计算两个数的百分比方法
Jun 29 Python
python批量复制图片到另一个文件夹
Sep 17 Python
简单了解Django应用app及分布式路由
Jul 24 Python
python递归下载文件夹下所有文件
Aug 31 Python
Python List列表对象内置方法实例详解
Oct 22 Python
Python decimal模块使用方法详解
Jun 08 Python
python使用Word2Vec进行情感分析解析
Jul 31 Python
Python 常用日期处理 -- calendar 与 dateutil 模块的使用
Sep 02 Python
Python3获取cookie常用三种方案
Oct 05 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
php 数组的一个悲剧?
2011/05/11 PHP
php中关于codeigniter的xmlrpc的类在进行数据交换时的类型问题
2011/07/03 PHP
IIS安装Apache伪静态插件的具体操作图文
2013/07/01 PHP
PHP原生函数一定好吗?
2014/12/08 PHP
javascript中巧用“闭包”实现程序的暂停执行功能
2007/04/04 Javascript
javascript当onmousedown、onmouseup、onclick同时应用于同一个标签节点Element
2010/01/05 Javascript
js客户端快捷键管理类的较完整实现和应用
2010/06/08 Javascript
JS操作JSON方法总结(推荐)
2016/06/14 Javascript
Javascript基础回顾之(二) js作用域
2017/01/31 Javascript
Vue 2.5.2下axios + express 本地请求404的解决方法
2018/02/21 Javascript
解决vue 格式化银行卡(信用卡)每4位一个符号隔断的问题
2018/09/14 Javascript
详解如何使用router-link对象方式传递参数?
2019/05/02 Javascript
JavaScript工具库之Lodash详解
2019/06/15 Javascript
小程序Request的另类用法详解
2019/08/09 Javascript
使用Element的InfiniteScroll 无限滚动组件报错的解决
2020/07/27 Javascript
[07:06]2018DOTA2国际邀请赛寻真——卫冕冠军Team Liquid
2018/08/10 DOTA
python写的ARP攻击代码实例
2014/06/04 Python
Python中List.index()方法的使用教程
2015/05/20 Python
Python的requests网络编程包使用教程
2016/07/11 Python
Python爬虫代理IP池实现方法
2017/01/05 Python
python使用tkinter实现简单计算器
2018/01/30 Python
Python实现JSON反序列化类对象的示例
2018/01/31 Python
python使用opencv驱动摄像头的方法
2018/08/03 Python
PyQt5 对图片进行缩放的实例
2019/06/18 Python
django的model操作汇整详解
2019/07/26 Python
基于python使用tibco ems代码实例
2019/12/20 Python
使用HTML5做的导航条详细步骤
2020/10/19 HTML / CSS
项目经理的岗位职责
2013/11/23 职场文书
法律专业学生的自我评价
2014/02/07 职场文书
销售人员职业生涯规划范文
2014/03/01 职场文书
常务副总经理任命书
2014/06/05 职场文书
珍惜资源的建议书
2014/08/26 职场文书
毕业论文答辩开场白和结束语
2015/05/27 职场文书
发票退票证明
2015/06/24 职场文书
微信小程序结合ThinkPHP5授权登陆后获取手机号
2021/11/23 PHP
详解如何使用Nginx解决跨域问题
2022/05/06 Servers