Pytorch 使用CNN图像分类的实现


Posted in Python onJune 16, 2020

需求

在4*4的图片中,比较外围黑色像素点和内圈黑色像素点个数的大小将图片分类

Pytorch 使用CNN图像分类的实现

如上图图片外围黑色像素点5个大于内圈黑色像素点1个分为0类反之1类

想法

  • 通过numpy、PIL构造4*4的图像数据集
  • 构造自己的数据集类
  • 读取数据集对数据集选取减少偏斜
  • cnn设计因为特征少,直接1*1卷积层
  • 或者在4*4外围添加padding成6*6,设计2*2的卷积核得出3*3再接上全连接层

代码

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

构造数据集

import csv
import collections
import os
import shutil

def buildDataset(root,dataType,dataSize):
  """构造数据集
  构造的图片存到root/{dataType}Data
  图片地址和标签的csv文件存到 root/{dataType}DataInfo.csv
  Args:
    root:str
      项目目录
    dataType:str
      'train'或者‘test'
    dataNum:int
      数据大小
  Returns:
  """
  dataInfo = []
  dataPath = f'{root}/{dataType}Data'
  if not os.path.exists(dataPath):
    os.makedirs(dataPath)
  else:
    shutil.rmtree(dataPath)
    os.mkdir(dataPath)
    
  for i in range(dataSize):
    # 创建0,1 数组
    imageArray=np.random.randint(0,2,(4,4))
    # 计算0,1数量得到标签
    allBlackNum = collections.Counter(imageArray.flatten())[0]
    innerBlackNum = collections.Counter(imageArray[1:3,1:3].flatten())[0]
    label = 0 if (allBlackNum-innerBlackNum)>innerBlackNum else 1
    # 将图片保存
    path = f'{dataPath}/{i}.jpg'
    dataInfo.append([path,label])
    im = Image.fromarray(np.uint8(imageArray*255))
    im = im.convert('1') 
    im.save(path)
  # 将图片地址和标签存入csv文件
  filePath = f'{root}/{dataType}DataInfo.csv'
  with open(filePath, 'w') as f:
    writer = csv.writer(f)
    writer.writerows(dataInfo)
root=r'/Users/null/Documents/PythonProject/Classifier'

构造训练数据集

buildDataset(root,'train',20000)

构造测试数据集

buildDataset(root,'test',10000)

读取数据集

class MyDataset(torch.utils.data.Dataset):

  def __init__(self, root, datacsv, transform=None):
    super(MyDataset, self).__init__()
    with open(f'{root}/{datacsv}', 'r') as f:
      imgs = []
      # 读取csv信息到imgs列表
      for path,label in map(lambda line:line.rstrip().split(','),f):
        imgs.append((path, int(label)))
    self.imgs = imgs
    self.transform = transform if transform is not None else lambda x:x
    
  def __getitem__(self, index):
    path, label = self.imgs[index]
    img = self.transform(Image.open(path).convert('1'))
    return img, label

  def __len__(self):
    return len(self.imgs)
trainData=MyDataset(root = root,datacsv='trainDataInfo.csv', transform=transforms.ToTensor())
testData=MyDataset(root = root,datacsv='testDataInfo.csv', transform=transforms.ToTensor())

处理数据集使得数据集不偏斜

import itertools

def chooseData(dataset,scale):
  # 将类别为1的排序到前面
  dataset.imgs.sort(key=lambda x:x[1],reverse=True)
  # 获取类别1的数目 ,取scale倍的数组,得数据不那么偏斜
  trueNum =collections.Counter(itertools.chain.from_iterable(dataset.imgs))[1]
  end = min(trueNum*scale,len(dataset))
  dataset.imgs=dataset.imgs[:end]
scale = 4
chooseData(trainData,scale)
chooseData(testData,scale)
len(trainData),len(testData)
(2250, 1122)
import torch.utils.data as Data

# 超参数
batchSize = 50
lr = 0.1
numEpochs = 20

trainIter = Data.DataLoader(dataset=trainData, batch_size=batchSize, shuffle=True)
testIter = Data.DataLoader(dataset=testData, batch_size=batchSize)

定义模型

from torch import nn
from torch.autograd import Variable
from torch.nn import Module,Linear,Sequential,Conv2d,ReLU,ConstantPad2d
import torch.nn.functional as F
class Net(Module):  
  def __init__(self):
    super(Net, self).__init__()

    self.cnnLayers = Sequential(
      # padding添加1层常数1,设定卷积核为2*2
      ConstantPad2d(1, 1),
      Conv2d(1, 1, kernel_size=2, stride=2,bias=True)
    )
    self.linearLayers = Sequential(
      Linear(9, 2)
    )

  def forward(self, x):
    x = self.cnnLayers(x)
    x = x.view(x.shape[0], -1)
    x = self.linearLayers(x)
    return x
class Net2(Module):  
  def __init__(self):
    super(Net2, self).__init__()

    self.cnnLayers = Sequential(
      Conv2d(1, 1, kernel_size=1, stride=1,bias=True)
    )
    self.linearLayers = Sequential(
      ReLU(),
      Linear(16, 2)
    )

  def forward(self, x):
    x = self.cnnLayers(x)
    x = x.view(x.shape[0], -1)
    x = self.linearLayers(x)
    return x

定义损失函数

# 交叉熵损失函数
loss = nn.CrossEntropyLoss()
loss2 = nn.CrossEntropyLoss()

定义优化算法

net = Net()
optimizer = torch.optim.SGD(net.parameters(),lr = lr)
net2 = Net2()
optimizer2 = torch.optim.SGD(net2.parameters(),lr = lr)

训练模型

# 计算准确率
def evaluateAccuracy(dataIter, net):
  accSum, n = 0.0, 0
  with torch.no_grad():
    for X, y in dataIter:
      accSum += (net(X).argmax(dim=1) == y).float().sum().item()
      n += y.shape[0]
  return accSum / n
def train(net, trainIter, testIter, loss, numEpochs, batchSize,
       optimizer):
  for epoch in range(numEpochs):
    trainLossSum, trainAccSum, n = 0.0, 0.0, 0
    for X,y in trainIter:
      yHat = net(X)
      l = loss(yHat,y).sum()
      optimizer.zero_grad()
      l.backward()
      optimizer.step()
      # 计算训练准确度和loss
      trainLossSum += l.item()
      trainAccSum += (yHat.argmax(dim=1) == y).sum().item()
      n += y.shape[0]
    # 评估测试准确度
    testAcc = evaluateAccuracy(testIter, net)
    print('epoch {:d}, loss {:.4f}, train acc {:.3f}, test acc {:.3f}'.format(epoch + 1, trainLossSum / n, trainAccSum / n, testAcc))

Net模型训练

train(net, trainIter, testIter, loss, numEpochs, batchSize,optimizer)
epoch 1, loss 0.0128, train acc 0.667, test acc 0.667
epoch 2, loss 0.0118, train acc 0.683, test acc 0.760
epoch 3, loss 0.0104, train acc 0.742, test acc 0.807
epoch 4, loss 0.0093, train acc 0.769, test acc 0.772
epoch 5, loss 0.0085, train acc 0.797, test acc 0.745
epoch 6, loss 0.0084, train acc 0.798, test acc 0.807
epoch 7, loss 0.0082, train acc 0.804, test acc 0.816
epoch 8, loss 0.0078, train acc 0.816, test acc 0.812
epoch 9, loss 0.0077, train acc 0.818, test acc 0.817
epoch 10, loss 0.0074, train acc 0.824, test acc 0.826
epoch 11, loss 0.0072, train acc 0.836, test acc 0.819
epoch 12, loss 0.0075, train acc 0.823, test acc 0.829
epoch 13, loss 0.0071, train acc 0.839, test acc 0.797
epoch 14, loss 0.0067, train acc 0.849, test acc 0.824
epoch 15, loss 0.0069, train acc 0.848, test acc 0.843
epoch 16, loss 0.0064, train acc 0.864, test acc 0.851
epoch 17, loss 0.0062, train acc 0.867, test acc 0.780
epoch 18, loss 0.0060, train acc 0.871, test acc 0.864
epoch 19, loss 0.0057, train acc 0.881, test acc 0.890
epoch 20, loss 0.0055, train acc 0.885, test acc 0.897

Net2模型训练

# batchSize = 50 
# lr = 0.1
# numEpochs = 15 下得出的结果
train(net2, trainIter, testIter, loss2, numEpochs, batchSize,optimizer2)

epoch 1, loss 0.0119, train acc 0.638, test acc 0.676
epoch 2, loss 0.0079, train acc 0.823, test acc 0.986
epoch 3, loss 0.0046, train acc 0.987, test acc 0.977
epoch 4, loss 0.0030, train acc 0.983, test acc 0.973
epoch 5, loss 0.0023, train acc 0.981, test acc 0.976
epoch 6, loss 0.0019, train acc 0.980, test acc 0.988
epoch 7, loss 0.0016, train acc 0.984, test acc 0.984
epoch 8, loss 0.0014, train acc 0.985, test acc 0.986
epoch 9, loss 0.0013, train acc 0.987, test acc 0.992
epoch 10, loss 0.0011, train acc 0.989, test acc 0.993
epoch 11, loss 0.0010, train acc 0.989, test acc 0.996
epoch 12, loss 0.0010, train acc 0.992, test acc 0.994
epoch 13, loss 0.0009, train acc 0.993, test acc 0.994
epoch 14, loss 0.0008, train acc 0.995, test acc 0.996
epoch 15, loss 0.0008, train acc 0.994, test acc 0.998

测试

test = torch.Tensor([[[[0,0,0,0],[0,1,1,0],[0,1,1,0],[0,0,0,0]]],
         [[[1,1,1,1],[1,0,0,1],[1,0,0,1],[1,1,1,1]]],
         [[[0,1,0,1],[1,0,0,1],[1,0,0,1],[0,0,0,1]]],
         [[[0,1,1,1],[1,0,0,1],[1,0,0,1],[0,0,0,1]]],
         [[[0,0,1,1],[1,0,0,1],[1,0,0,1],[1,0,1,0]]],
         [[[0,0,1,0],[0,1,0,1],[0,0,1,1],[1,0,1,0]]],
         [[[1,1,1,0],[1,0,0,1],[1,0,1,1],[1,0,1,1]]]
         ])

target=torch.Tensor([0,1,0,1,1,0,1])
test
tensor([[[[0., 0., 0., 0.],
     [0., 1., 1., 0.],
     [0., 1., 1., 0.],
     [0., 0., 0., 0.]]],

​

    [[[1., 1., 1., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [1., 1., 1., 1.]]],

​

    [[[0., 1., 0., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [0., 0., 0., 1.]]],

​

    [[[0., 1., 1., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [0., 0., 0., 1.]]],

​

    [[[0., 0., 1., 1.],
     [1., 0., 0., 1.],
     [1., 0., 0., 1.],
     [1., 0., 1., 0.]]],

​

    [[[0., 0., 1., 0.],
     [0., 1., 0., 1.],
     [0., 0., 1., 1.],
     [1., 0., 1., 0.]]],

​

    [[[1., 1., 1., 0.],
     [1., 0., 0., 1.],
     [1., 0., 1., 1.],
     [1., 0., 1., 1.]]]])



with torch.no_grad():
  output = net(test)
  output2 = net2(test)
predictions =output.argmax(dim=1)
predictions2 =output2.argmax(dim=1)
# 比较结果
print(f'Net测试结果{predictions.eq(target)}')
print(f'Net2测试结果{predictions2.eq(target)}')
Net测试结果tensor([ True, True, False, True, True, True, True])
Net2测试结果tensor([False, True, False, True, True, False, True])

到此这篇关于Pytorch 使用CNN图像分类的实现的文章就介绍到这了,更多相关Pytorch CNN图像分类内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中的rjust()方法使用详解
May 19 Python
Python每天必学之bytes字节
Jan 28 Python
Python实现的计数排序算法示例
Nov 29 Python
python实现神经网络感知器算法
Dec 20 Python
Python上下文管理器全实例详解
Nov 12 Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 Python
推荐技术人员一款Python开源库(造数据神器)
Jul 08 Python
如何用Python绘制3D柱形图
Sep 16 Python
python实现二分查找算法
Sep 18 Python
Python urllib库如何添加headers过程解析
Oct 05 Python
python如何做代码性能分析
Apr 26 Python
如何利用python创作字符画
Jun 25 Python
利用python中的matplotlib打印混淆矩阵实例
Jun 16 #Python
Python SMTP配置参数并发送邮件
Jun 16 #Python
基于matplotlib中ion()和ioff()的使用详解
Jun 16 #Python
Python数据相关系数矩阵和热力图轻松实现教程
Jun 16 #Python
matplotlib.pyplot.matshow 矩阵可视化实例
Jun 16 #Python
使用python matploblib库绘制准确率,损失率折线图
Jun 16 #Python
为什么称python为胶水语言
Jun 16 #Python
You might like
VFP与其他应用程序的集成
2006/10/09 PHP
PHP安装攻略:常见问题解答(一)
2006/10/09 PHP
php中根据某年第几天计算出日期年月日的代码
2011/02/24 PHP
基于PHP一些十分严重的缺陷详解
2013/06/03 PHP
一个简洁的PHP可逆加密函数(分享)
2013/06/06 PHP
php像数组一样存取和修改字符串字符
2014/03/21 PHP
刷新PHP缓冲区为你的站点加速
2015/10/10 PHP
编写PHP脚本使WordPress的主题支持Widget侧边栏
2015/12/14 PHP
对比PHP对MySQL的缓冲查询和无缓冲查询
2016/07/01 PHP
关于laravel 数据库迁移中integer类型是无法指定长度的问题
2019/10/09 PHP
十分钟打造AutoComplete自动完成效果代码
2009/12/26 Javascript
jQuery 仿百度输入标签插件附效果图
2014/07/04 Javascript
跟我学习javascript的浮点数精度
2015/11/16 Javascript
JavaScript tab选项卡插件实例代码
2016/02/23 Javascript
Javascript生成全局唯一标识符(GUID,UUID)的方法
2016/02/27 Javascript
浅谈JQuery+ajax+jsonp 跨域访问
2016/06/25 Javascript
简单理解vue中track-by属性
2016/10/26 Javascript
webpack打包并将文件加载到指定的位置方法
2018/02/22 Javascript
JS中DOM元素的attribute与property属性示例详解
2018/09/04 Javascript
4个顶级JavaScript高级文本编辑器
2018/10/10 Javascript
详解vue 项目白屏解决方案
2018/10/31 Javascript
[51:36]EG vs VP 2018国际邀请赛淘汰赛BO3 第一场 8.24
2018/08/25 DOTA
pandas 读取各种格式文件的方法
2018/06/22 Python
python爬虫库scrapy简单使用实例详解
2020/02/10 Python
python3字符串输出常见面试题总结
2020/12/01 Python
纯CSS3实现带动画效果导航菜单无需js
2013/09/27 HTML / CSS
CSS中几个与换行有关的属性简明总结
2014/04/15 HTML / CSS
德国隐形眼镜店:LuckyLens
2018/07/29 全球购物
KTV的创业计划书范文
2014/02/02 职场文书
生日庆典策划方案
2014/06/02 职场文书
在职员工证明书
2014/09/19 职场文书
2014院党委领导班子及其成员群众路线对照检查材料思想汇报
2014/10/04 职场文书
承兑汇票转让证明怎么写?
2014/11/30 职场文书
匿名信格式范文
2015/05/27 职场文书
居安思危观后感
2015/06/11 职场文书
军训新闻稿范文
2015/07/17 职场文书