基于PyTorch实现一个简单的CNN图像分类器


Posted in Python onMay 29, 2021

pytorch中文网:https://www.pytorchtutorial.com/
pytorch官方文档:https://pytorch.org/docs/stable/index.html

一. 加载数据

Pytorch的数据加载一般是用torch.utils.data.Dataset与torch.utils.data.Dataloader两个类联合进行。我们需要继承Dataset来定义自己的数据集类,然后在训练时用Dataloader加载自定义的数据集类。

1. 继承Dataset类并重写关键方法

pytorch的dataset类有两种:Map-style datasets和Iterable-style datasets。前者是我们常用的结构,而后者是当数据集难以(或不可能)进行随机读取时使用。在这里我们实现Map-style dataset。
继承torch.utils.data.Dataset后,需要重写的方法有:__len__与__getitem__方法,其中__len__方法需要返回所有数据的数量,而__getitem__则是要依照给出的数据索引获取对应的tensor类型的Sample,除了这两个方法以外,一般还需要实现__init__方法来初始化一些变量。话不多说,直接上代码。

'''
包括了各种数据集的读取处理,以及图像相关处理方法
'''
from torch.utils.data import Dataset
import torch
import os
import cv2
from Config import mycfg
import random
import numpy as np


class ImageClassifyDataset(Dataset):
    def __init__(self, imagedir, labelfile, classify_num, train=True):
    	'''
    	这里进行一些初始化操作。
    	'''
        self.imagedir = imagedir
        self.labelfile = labelfile
        self.classify_num = classify_num
        self.img_list = []
        # 读取标签
        with open(self.labelfile, 'r') as fp:
            lines = fp.readlines()
            for line in lines:
                filepath = os.path.join(self.imagedir, line.split(";")[0].replace('\\', '/'))
                label = line.split(";")[1].strip('\n')
                self.img_list.append((filepath, label))
        if not train:
            self.img_list = random.sample(self.img_list, 50)

    def __len__(self):
        return len(self.img_list)
        
    def __getitem__(self, item):
	    '''
	    这个函数是关键,通过item(索引)来取数据集中的数据,
	    一般来说在这里才将图像数据加载入内存,之前存的是图像的保存路径
	    '''
        _int_label = int(self.img_list[item][1])	# label直接用0,1,2,3,4...表示不同类别
        label = torch.tensor(_int_label,dtype=torch.long)
        img = self.ProcessImgResize(self.img_list[item][0])
        return img, label

    def ProcessImgResize(self, filename):
    	'''
    	对图像进行一些预处理
    	'''
        _img = cv2.imread(filename)
        _img = cv2.resize(_img, (mycfg.IMG_WIDTH, mycfg.IMG_HEIGHT), interpolation=cv2.INTER_CUBIC)
        _img = _img.transpose((2, 0, 1))
        _img = _img / 255
        _img = torch.from_numpy(_img)
        _img = _img.to(torch.float32)
        return _img

有一些的数据集类一般还会传入一个transforms函数来构造一个图像预处理序列,传入transforms函数的一个好处是作为参数传入的话可以对一些非本地数据集中的数据进行操作(比如直接通过torchvision获取的一些预存数据集CIFAR10等等),除此之外就是torchvision.transforms里面有一些预定义的图像操作函数,可以直接像拼积木一样拼成一个图像处理序列,很方便。我这里因为是用我自己下载到本地的数据集,而且比较简单就直接用自己的函数来操作了。

2. 使用Dataloader加载数据

实例化自定义的数据集类ImageClassifyDataset后,将其传给DataLoader作为参数,得到一个可遍历的数据加载器。可以通过参数batch_size控制批处理大小,shuffle控制是否乱序读取,num_workers控制用于读取数据的线程数量。

from torch.utils.data import DataLoader
from MyDataset import ImageClassifyDataset

dataset = ImageClassifyDataset(imagedir, labelfile, 10)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True,num_workers=5)
for index, data in enumerate(dataloader):
	print(index)	# batch索引
	print(data)		# 一个batch的{img,label}

二. 模型设计

在这里只讨论深度学习模型的设计,pytorch中的网络结构是一层一层叠出来的,pytorch中预定义了许多可以通过参数控制的网络层结构,比如Linear、CNN、RNN、Transformer等等具体可以查阅官方文档中的torch.nn部分。
设计自己的模型结构需要继承torch.nn.Module这个类,然后实现其中的forward方法,一般在__init__中设定好网络模型的一些组件,然后在forward方法中依据输入输出顺序拼装组件。

'''
包括了各种模型、自定义的loss计算方法、optimizer
'''
import torch.nn as nn


class Simple_CNN(nn.Module):
    def __init__(self, class_num):
        super(Simple_CNN, self).__init__()
        self.class_num = class_num
        self.conv1 = nn.Sequential(
            nn.Conv2d(		# input: 3,400,600
                in_channels=3,
                out_channels=8,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.Conv2d(
                in_channels=8,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.AvgPool2d(2),  # 16,400,600 --> 16,200,300
            nn.BatchNorm2d(16),
            nn.LeakyReLU(),
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.Conv2d(
                in_channels=16,
                out_channels=8,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.AvgPool2d(2),  # 8,200,300 --> 8,100,150
            nn.BatchNorm2d(8),
            nn.LeakyReLU(),
            nn.Conv2d(
                in_channels=8,
                out_channels=8,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.Conv2d(
                in_channels=8,
                out_channels=1,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.AvgPool2d(2),  # 1,100,150 --> 1,50,75
            nn.BatchNorm2d(1),
            nn.LeakyReLU()
        )
        self.line = nn.Sequential(
            nn.Linear(
                in_features=50 * 75,
                out_features=self.class_num
            ),
            nn.Softmax()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(-1, 50 * 75)
        y = self.line(x)
        return y

上面我定义的模型中包括卷积组件conv1和全连接组件line,卷积组件中包括了一些卷积层,一般是按照{卷积层、池化层、激活函数}的顺序拼接,其中我还在激活函数之前添加了一个BatchNorm2d层对上层的输出进行正则化以免传入激活函数的值过小(梯度消失)或过大(梯度爆炸)。
在拼接组件时,由于我全连接层的输入是一个一维向量,所以需要将卷积组件中最后的50 × 75 50\times 7550×75大小的矩阵展平成一维的再传入全连接层(x.view(-1,50*75))

三. 训练

实例化模型后,网络模型的训练需要定义损失函数与优化器,损失函数定义了网络输出与标签的差距,依据不同的任务需要定义不同的合适的损失函数,而优化器则定义了神经网络中的参数如何基于损失来更新,目前神经网络最常用的优化器就是SGD(随机梯度下降算法) 及其变种。
在我这个简单的分类器模型中,直接用的多分类任务最常用的损失函数CrossEntropyLoss()以及优化器SGD。

self.cnnmodel = Simple_CNN(mycfg.CLASS_NUM)
self.criterion = nn.CrossEntropyLoss()	# 交叉熵,标签应该是0,1,2,3...的形式而不是独热的
self.optimizer = optim.SGD(self.cnnmodel.parameters(), lr=mycfg.LEARNING_RATE, momentum=0.9)

训练过程其实很简单,使用dataloader依照batch读出数据后,将input放入网络模型中计算得到网络的输出,然后基于标签通过损失函数计算Loss,并将Loss反向传播回神经网络(在此之前需要清理上一次循环时的梯度),最后通过优化器更新权重。训练部分代码如下:

for each_epoch in range(mycfg.MAX_EPOCH):
            running_loss = 0.0
            self.cnnmodel.train()
            for index, data in enumerate(self.dataloader):
                inputs, labels = data
                outputs = self.cnnmodel(inputs)
                loss = self.criterion(outputs, labels)

                self.optimizer.zero_grad()	# 清理上一次循环的梯度
                loss.backward()	# 反向传播
                self.optimizer.step()	# 更新参数
                running_loss += loss.item()
                if index % 200 == 199:
                    print("[{}] loss: {:.4f}".format(each_epoch, running_loss/200))
                    running_loss = 0.0
            # 保存每一轮的模型
            model_name = 'classify-{}-{}.pth'.format(each_epoch,round(all_loss/all_index,3))
            torch.save(self.cnnmodel,model_name)	# 保存全部模型

四. 测试

测试和训练的步骤差不多,也就是读取模型后通过dataloader获取数据然后将其输入网络获得输出,但是不需要进行反向传播的等操作了。比较值得注意的可能就是准确率计算方面有一些小技巧。

acc = 0.0
count = 0
self.cnnmodel = torch.load('mymodel.pth')
self.cnnmodel.eval()
for index, data in enumerate(dataloader_eval):
	inputs, labels = data   # 5,3,400,600  5,10
	count += len(labels)
	outputs = cnnmodel(inputs)
	_,predict = torch.max(outputs, 1)
	acc += (labels == predict).sum().item()
print("[{}] accurancy: {:.4f}".format(each_epoch, acc / count))

我这里采用的是保存全部模型并加载全部模型的方法,这种方法的好处是在使用模型时可以完全将其看作一个黑盒,但是在模型比较大时这种方法会很费事。此时可以采用只保存参数不保存网络结构的方法,在每一次使用模型时需要读取参数赋值给已经实例化的模型:

torch.save(cnnmodel.state_dict(), "my_resnet.pth")
cnnmodel = Simple_CNN()
cnnmodel.load_state_dict(torch.load("my_resnet.pth"))

结语

至此整个流程就说完了,是一个小白级的图像分类任务流程,因为前段时间一直在做android方面的事,所以有点生疏了,就写了这篇博客记录一下,之后应该还会写一下seq2seq以及image caption任务方面的模型构造与训练过程,完整代码之后也会统一放到github上给大家做参考。

以上就是基于PyTorch实现一个简单的CNN图像分类器的详细内容,更多关于PyTorch实现CNN图像分类器的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python通过wxPython打开一个音频文件并播放的方法
Mar 25 Python
在Python的Django框架上部署ORM库的教程
Apr 20 Python
python实现unicode转中文及转换默认编码的方法
Apr 29 Python
Django admin美化插件suit使用示例
Dec 12 Python
python编程使用selenium模拟登陆淘宝实例代码
Jan 25 Python
详解Python正则表达式re模块
Mar 19 Python
python跳出双层for循环的解决方法
Jun 24 Python
python实现键盘输入的实操方法
Jul 16 Python
Python生命游戏实现原理及过程解析(附源代码)
Aug 01 Python
Tensorflow中tf.ConfigProto()的用法详解
Feb 06 Python
Python内置类型集合set和frozenset的使用详解
Apr 26 Python
Python  序列化反序列化和异常处理的问题小结
Dec 24 Python
python 爬取华为应用市场评论
python 开心网和豆瓣日记爬取的小爬虫
May 29 #Python
Python趣味挑战之实现简易版音乐播放器
新手必备Python开发环境搭建教程
Keras多线程机制与flask多线程冲突的解决方案
May 28 #Python
pytorch 6 batch_train 批训练操作
May 28 #Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
You might like
比较完整的微信开发php代码
2016/08/02 PHP
PHP实现链式操作的原理详解
2016/09/16 PHP
浅谈php使用curl模拟多线程发送请求
2019/03/08 PHP
JS操作Cookies包括(读取添加与删除)
2012/12/26 Javascript
JS声明变量背后的编译原理剖析
2012/12/28 Javascript
SOSO地图JS画出标注和中心点以html形式运行
2013/08/09 Javascript
JS实现重新加载当前页面或者父页面的几种方法
2016/11/30 Javascript
微信小程序  http请求封装详解及实例代码
2017/02/15 Javascript
angularJs使用$watch和$filter过滤器制作搜索筛选实例
2017/06/01 Javascript
关于使用axios的一些心得技巧分享
2017/07/02 Javascript
vue中v-model动态生成的实例详解
2017/10/27 Javascript
JS中的事件委托实例浅析
2018/03/22 Javascript
详解create-react-app 自定义 eslint 配置
2018/06/07 Javascript
js canvas画布实现高斯模糊效果
2018/11/27 Javascript
你不知道的Vue技巧之--开发一个可以通过方法调用的组件(推荐)
2019/04/15 Javascript
layui 点击重置按钮, select 并没有被重置的解决方法
2019/09/03 Javascript
JavaScript数值类型知识汇总
2019/11/17 Javascript
js实现九宫格抽奖
2020/03/19 Javascript
Nuxt 嵌套路由nuxt-child组件用法(父子页面组件的传值)
2020/11/05 Javascript
Python的Bottle框架中获取制定cookie的教程
2015/04/24 Python
linux环境下的python安装过程图解(含setuptools)
2017/11/22 Python
python使用pil库实现图片合成实例代码
2018/01/20 Python
解决PySide+Python子线程更新UI线程的问题
2019/01/11 Python
Python3.5 + sklearn利用SVM自动识别字母验证码方法示例
2019/05/10 Python
使用PyQt的QLabel组件实现选定目标框功能的方法示例
2020/05/19 Python
怎样实现H5+CSS3手指滑动切换图片的示例代码
2019/05/05 HTML / CSS
使用canvas压缩图片大小的方法示例
2019/08/02 HTML / CSS
阿拉伯世界最大的电子商务网站:Souq沙特阿拉伯
2016/10/28 全球购物
生物技术研究生自荐信
2013/11/12 职场文书
社会实践心得体会
2014/01/03 职场文书
小学生安全保证书
2014/02/01 职场文书
投标人廉洁自律承诺书
2014/05/26 职场文书
会议邀请函
2015/01/30 职场文书
2016感恩父亲节主题广播稿
2015/12/18 职场文书
LayUI+Shiro实现动态菜单并记住菜单收展的示例
2021/05/06 Javascript
HTML基础详解(上)
2021/10/16 HTML / CSS