pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 命令行参数sys.argv
Sep 06 Python
python网络编程学习笔记(七):HTML和XHTML解析(HTMLParser、BeautifulSoup)
Jun 09 Python
python实现web方式logview的方法
Aug 10 Python
详解设计模式中的工厂方法模式在Python程序中的运用
Mar 02 Python
Django自定义用户认证示例详解
Mar 14 Python
对numpy中轴与维度的理解
Apr 18 Python
python简单验证码识别的实现方法
May 10 Python
python使用Qt界面以及逻辑实现方法
Jul 10 Python
详解Python3 pandas.merge用法
Sep 05 Python
探秘TensorFlow 和 NumPy 的 Broadcasting 机制
Mar 13 Python
Linux安装Python3如何和系统自带的Python2并存
Jul 23 Python
如何解决python多种版本冲突问题
Oct 13 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
php像数组一样存取和修改字符串字符
2014/03/21 PHP
老生常谈ThinkPHP中的行为扩展和插件(推荐)
2017/05/05 PHP
PHP实现链式操作的三种方法详解
2017/11/16 PHP
TP5框架model常见操作示例小结【增删改查、聚合、时间戳、软删除等】
2020/04/05 PHP
HTML中事件触发列表与解说
2007/07/09 Javascript
10个基于jQuery或JavaScript的WYSIWYG 编辑器整理
2010/05/06 Javascript
动态加载外部javascript文件的函数代码分享
2011/07/28 Javascript
JavaScript数组深拷贝和浅拷贝的两种方法
2014/04/16 Javascript
JS函数重载的解决方案
2014/05/13 Javascript
深入了解Node.js中的一些特性
2014/09/25 Javascript
微信支付如何实现内置浏览器的H5页面支付
2015/09/25 Javascript
理解javascript模块化
2016/03/28 Javascript
原生js的数组除重复简单实例
2016/05/24 Javascript
微信小程序 教程之条件渲染
2016/10/18 Javascript
使用BootStrap和Metroui设计的metro风格微网站或手机app界面
2016/10/21 Javascript
微信小程序 toast 详解及实例代码
2016/11/09 Javascript
JS跨域请求外部服务器的资源
2017/02/06 Javascript
微信小程序实战之自定义toast(6)
2017/04/18 Javascript
vue-cli webpack 引入jquery的方法
2018/01/10 jQuery
vue 使某个组件不被 keep-alive 缓存的方法
2018/09/21 Javascript
小程序getLocation需要在app.json中声明permission字段
2019/04/04 Javascript
一篇文章,教你学会Vue CLI 插件开发
2019/04/17 Javascript
微信小程序在其他页面监听globalData中值的变化
2019/07/15 Javascript
NumPy 数学函数及代数运算的实现代码
2018/07/18 Python
python实现时间o(1)的最小栈的实例代码
2018/07/23 Python
学生信息管理系统python版
2018/10/17 Python
django-rest-framework解析请求参数过程详解
2019/07/18 Python
wxPython:python首选的GUI库实例分享
2019/10/05 Python
Django 构建模板form表单的两种方法
2020/06/14 Python
马来西亚排名第一的宠物用品店:Pets Wonderland
2020/04/16 全球购物
求职信模版
2013/11/30 职场文书
篮球社团活动总结
2014/06/27 职场文书
高中学校对照检查材料
2014/08/31 职场文书
2015年基层党组织公开承诺书
2015/01/21 职场文书
2016继续教育研修日志
2015/11/13 职场文书
2019暑假学生安全口号
2019/06/27 职场文书