深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
使用PyCharm配合部署Python的Django框架的配置纪实
Nov 19 Python
Python中struct模块对字节流/二进制流的操作教程
Jan 21 Python
Python 单元测试(unittest)的使用小结
Nov 14 Python
对python For 循环的三种遍历方式解析
Feb 01 Python
Python实现字符型图片验证码识别完整过程详解
May 10 Python
Python操作excel的方法总结(xlrd、xlwt、openpyxl)
Sep 02 Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 Python
浅谈tensorflow 中的图片读取和裁剪方式
Jun 30 Python
Python 如何对文件目录操作
Jul 10 Python
用python实现前向分词最大匹配算法的示例代码
Aug 06 Python
Python importlib模块重载使用方法详解
Oct 13 Python
python脚本定时发送邮件
Dec 22 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
PHP下操作Linux消息队列完成进程间通信的方法
2010/07/24 PHP
通过PHP的内置函数,通过DES算法对数据加密和解密
2012/06/21 PHP
php读取excel文件的简单实例
2013/08/26 PHP
PHP实现的比较完善的购物车类
2014/12/02 PHP
php版微信开发Token验证失败或请求URL超时问题的解决方法
2016/09/23 PHP
总结AJAX相关JS代码片段和浏览器模型
2007/08/15 Javascript
JavaScript高级程序设计 学习笔记 js高级技巧
2011/09/20 Javascript
JavaScript中的面向对象介绍
2012/06/30 Javascript
chrome不支持form.submit的解决方案
2015/04/28 Javascript
简单对比分析JavaScript中的apply,call与this的使用
2015/12/04 Javascript
浅谈JavaScript中的分支结构
2016/07/01 Javascript
JS 滚动事件window.onscroll与position:fixed写兼容IE6的回到顶部组件
2016/10/10 Javascript
基于input框覆盖掉数字英文的实例讲解
2017/07/21 Javascript
Angular 4.x+Ionic3踩坑之Ionic3.x pop反向传值详解
2018/03/13 Javascript
简单的三步vuex入门
2018/05/20 Javascript
详解auto-vue-file:一个自动创建vue组件的包
2019/04/26 Javascript
微信小程序人脸识别功能代码实例
2019/05/07 Javascript
解决layui的form里的元素进行动态生成,验证失效的问题
2019/09/14 Javascript
JavaScript中的相等操作符使用详解
2019/12/21 Javascript
[04:14]从西雅图到上海——玩家自制DOTA2主题歌曲应援TI9
2019/07/11 DOTA
[46:14]完美世界DOTA2联赛PWL S3 Magma vs INK ICE 第一场 12.11
2020/12/16 DOTA
Python3 伪装浏览器的方法示例
2017/11/23 Python
Django框架中间件(Middleware)用法实例分析
2019/05/24 Python
简单介绍一下pyinstaller打包以及安全性的实现
2020/06/02 Python
python实现PolynomialFeatures多项式的方法
2021/01/06 Python
解决tensorflow模型压缩的问题_踩坑无数,总算搞定
2021/03/02 Python
利用 Canvas实现绘画一个未闭合的带进度条的圆环
2019/07/26 HTML / CSS
新加坡最佳婴儿用品店:Mamahood.com.sg
2018/08/26 全球购物
罗兰·穆雷官网:Roland Mouret
2018/09/28 全球购物
STP的判定过程
2012/10/01 面试题
高中军训的心得体会
2014/09/01 职场文书
家长通知书家长意见
2015/06/03 职场文书
四大名著读书笔记
2015/06/25 职场文书
课改心得体会范文
2016/01/25 职场文书
MySQL大小写敏感的注意事项
2021/05/24 MySQL
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
2021/05/27 Python