pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 列表删除所有指定元素的方法
Apr 19 Python
django缓存配置的几种方法详解
Jul 16 Python
Flask入门之上传文件到服务器的方法示例
Jul 18 Python
python实现遍历文件夹修改文件后缀
Aug 28 Python
PyQt5实现让QScrollArea支持鼠标拖动的操作方法
Jun 19 Python
python使用装饰器作日志处理的方法
Jul 11 Python
利用Python校准本地时间的方法教程
Oct 31 Python
python如何把字符串类型list转换成list
Feb 18 Python
Python PIL库图片灰化处理
Apr 07 Python
vscode写python时的代码错误提醒和自动格式化的方法
May 07 Python
学会Python数据可视化必须尝试这7个库
Jun 16 Python
解析目标检测之IoU
Jun 26 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
CodeIgniter中实现泛域名解析
2014/07/19 PHP
PHP伪造来源HTTP_REFERER的方法实例详解
2015/07/06 PHP
深入研究PHP中的preg_replace和代码执行
2018/08/15 PHP
jQuery实现商品活动倒计时
2015/10/16 Javascript
jQuery插件Validate实现自定义表单验证
2016/01/18 Javascript
基于BootStrap Metronic开发框架经验小结【三】下拉列表Select2插件的使用
2016/05/12 Javascript
每日十条JavaScript经验技巧(二)
2016/06/23 Javascript
JavaScript 中 avalon绑定属性总结
2016/10/19 Javascript
微信JSAPI支付操作需要注意的细节
2017/01/10 Javascript
详谈js遍历集合(Array,Map,Set)
2017/04/06 Javascript
荐书|您有一份JavaScript书单待签收
2017/07/21 Javascript
浅谈JS中的反柯里化( uncurrying)
2017/08/17 Javascript
Vue利用History记录上一页面的数据方法实例
2018/11/02 Javascript
如何封装了一个vue移动端下拉加载下一页数据的组件
2019/01/06 Javascript
简单了解vue.js数组的常用操作
2019/06/17 Javascript
python多线程编程中的join函数使用心得
2014/09/02 Python
python提取字典key列表的方法
2015/07/11 Python
在Python中定义和使用抽象类的方法
2016/06/30 Python
Python Socket传输文件示例
2017/01/16 Python
python实现人脸识别代码
2017/11/08 Python
Python实现嵌套列表及字典并按某一元素去重复功能示例
2017/11/30 Python
Python获取指定文件夹下的文件名的方法
2018/02/06 Python
Python字典遍历操作实例小结
2019/03/05 Python
django 微信网页授权登陆的实现
2019/07/30 Python
基于Python检测动态物体颜色过程解析
2019/12/04 Python
Pytorch 实现focal_loss 多类别和二分类示例
2020/01/14 Python
python numpy 矩阵堆叠实例
2020/01/17 Python
python实现音乐播放和下载小程序功能
2020/04/26 Python
Python 3.9的到来到底是意味着什么
2020/10/14 Python
Marmot土拨鼠官网:美国专业户外运动品牌
2018/01/11 全球购物
WINDOWS域的具体实现方式是什么
2014/02/20 面试题
详解如何解决使用JSON.stringify时遇到的循环引用问题
2021/03/23 Javascript
大学生职业生涯规划范文——找准自我,定位人生
2014/01/23 职场文书
七匹狼男装广告词
2014/03/21 职场文书
2014年环卫工作总结
2014/11/22 职场文书
职工培训工作总结
2015/08/10 职场文书