深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python中查看变量内存地址的方法
May 05 Python
python操作mysql数据库
Mar 05 Python
Python Flask前后端Ajax交互的方法示例
Jul 31 Python
python 异或加密字符串的实例
Oct 14 Python
Python 实现中值滤波、均值滤波的方法
Jan 09 Python
Python os.access()用法实例
Feb 18 Python
python爬取酷狗音乐排行榜
Feb 20 Python
Python使用python-docx读写word文档
Aug 26 Python
Python3实现zip分卷压缩过程解析
Oct 09 Python
python 基于opencv操作摄像头
Dec 24 Python
解决pytorch 的state_dict()拷贝问题
Mar 03 Python
快速一键生成Python爬虫请求头
Mar 04 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
laravel框架添加数据,显示数据,返回成功值的方法
2019/10/11 PHP
Mootools 1.2教程 排序类和方法简介
2009/09/15 Javascript
javascript之通用简单的table选项卡实现(二)
2010/05/09 Javascript
js 金额格式化来回转换示例
2014/02/23 Javascript
JavaScript中的索引数组、关联数组和静态数组、动态数组讲解
2014/11/08 Javascript
JavaScript解析json格式数据简单示例
2014/12/09 Javascript
jquery判断至少有一个checkbox被选中的方法
2015/06/05 Javascript
Javascript生成带参数的二维码示例
2016/10/10 Javascript
使用requirejs模块化开发多页面一个入口js的使用方式
2017/06/14 Javascript
Kindeditor单独调用多图上传实例
2017/07/31 Javascript
NestJs 静态目录配置详解
2019/03/12 Javascript
详解微信小程序支付流程与梳理
2019/07/16 Javascript
Nodejs中使用puppeteer控制浏览器中视频播放功能
2019/08/26 NodeJs
VUE单页面切换动画代码(全网最好的切换效果)
2019/10/31 Javascript
解决Vue router-link绑定事件不生效的问题
2020/07/22 Javascript
浅谈vue 多个变量同时赋相同值互相影响
2020/08/05 Javascript
Chrome插件开发系列一:弹窗终结者开发实战
2020/10/02 Javascript
python自动化工具日志查询分析脚本代码实现
2013/11/26 Python
Python自动发送邮件的方法实例总结
2018/12/08 Python
python实现合并两个排序的链表
2019/03/03 Python
详解Python Matplotlib解决绘图X轴值不按数组排序问题
2019/08/05 Python
利用python实现周期财务统计可视化
2019/08/25 Python
使用Matplotlib 绘制精美的数学图形例子
2019/12/13 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
2020/09/17 Python
2020最新pycharm汉化安装(python工程狮亲测有效)
2020/04/26 Python
哥伦比亚最大的网上商店:Linio哥伦比亚
2016/09/25 全球购物
Sephora丝芙兰菲律宾官方网站:购买化妆品和护肤品
2017/04/05 全球购物
波兰最大的宠物用品网上商店:FERA.PL
2019/08/11 全球购物
《小熊住山洞》教学反思
2014/02/21 职场文书
2014植树节活动总结
2014/03/11 职场文书
教育实践活动对照检查材料
2014/09/23 职场文书
2014年团队工作总结
2014/11/24 职场文书
居委会工作总结2015
2015/05/18 职场文书
2016同学毕业寄语大全
2015/12/04 职场文书
Nginx设置日志打印post请求参数的方法
2021/03/31 Servers
如何通过简单的代码描述Angular父组件、子组件传值
2022/04/07 Javascript