深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python StringIO模块实现在内存缓冲区中读写数据
Apr 08 Python
Python中matplotlib中文乱码解决办法
May 12 Python
Python3安装Pymongo详细步骤
May 26 Python
Python获取本机所有网卡ip,掩码和广播地址实例代码
Jan 22 Python
Python FTP两个文件夹间的同步实例代码
May 25 Python
pytorch+lstm实现的pos示例
Jan 14 Python
Python 多线程共享变量的实现示例
Apr 17 Python
Python实现封装打包自己写的代码,被python import
Jul 12 Python
python3 中使用urllib问题以及urllib详解
Aug 03 Python
python Scrapy框架原理解析
Jan 04 Python
详解分布式系统中如何用python实现Paxos
May 18 Python
Python采集壁纸并实现炫轮播
Apr 30 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
php intval的测试代码发现问题
2008/07/27 PHP
PHP图片验证码制作实现分享(全)
2012/05/10 PHP
深入了解 register_globals (附register_globals=off 网站打不开的解决方法)
2012/06/27 PHP
zf框架的registry(注册表)使用示例
2014/03/13 PHP
php输出xml必须header的解决方法
2014/10/17 PHP
php网站被挂木马后的修复方法总结
2014/11/06 PHP
优秀js开源框架-jQuery使用手册(1)
2007/03/10 Javascript
javascript div 遮罩层封锁整个页面
2009/07/10 Javascript
JS,Jquery获取select,dropdownlist,checkbox下拉列表框的值(示例代码)
2014/01/11 Javascript
JavaScript实现检查页面上的广告是否被AdBlock屏蔽了的方法
2014/11/03 Javascript
Javascript函数的参数
2015/07/16 Javascript
jquery实现仿新浪微博带动画效果弹出层代码(可关闭、可拖动)
2015/10/12 Javascript
JS操作XML实例总结(加载与解析XML文件、字符串)
2015/12/08 Javascript
javascript实现拖放效果
2015/12/16 Javascript
浅谈JS中的三种字符串连接方式及其性能比较
2016/09/02 Javascript
jQuery插件artDialog.js使用与关闭方法示例
2017/10/09 jQuery
Bootstrap Table实现定时刷新数据的方法
2018/08/13 Javascript
vscode 插件开发 + vue的操作方法
2020/06/05 Javascript
详解ES6 扩展运算符的使用与注意事项
2020/11/12 Javascript
Python文件夹与文件的操作实现代码
2014/07/13 Python
Python3.5 Pandas模块之DataFrame用法实例分析
2019/04/23 Python
Python 支持向量机分类器的实现
2020/01/15 Python
HTML5 Canvas绘制圆点虚线实例
2015/01/01 HTML / CSS
英国最大的电脑零售连锁店集团:PC World
2016/10/10 全球购物
美国时尚假发购物网站:Wigsbuy
2019/04/06 全球购物
美国运动鞋类和服装零售连锁店:Shoe Palace
2019/08/13 全球购物
SQL Server提供的3种恢复模型都是什么? 有什么区别?
2012/05/13 面试题
心理健康课教学反思
2014/02/13 职场文书
小学数学国培感言
2014/03/10 职场文书
村班子对照检查材料
2014/08/18 职场文书
2015年人力资源部工作总结
2015/04/30 职场文书
行政申诉状范文
2015/05/20 职场文书
婚礼家长致辞
2015/07/27 职场文书
小学三年级作文之写景
2019/11/05 职场文书
CSS的class与id常用的命名规则
2021/05/18 HTML / CSS
CSS+HTML 实现顶部导航栏功能
2021/08/30 HTML / CSS