pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用select模块实现非阻塞的IO
Feb 03 Python
python使用Queue在多个子进程间交换数据的方法
Apr 18 Python
pandas去除重复列的实现方法
Jan 29 Python
Python类的继承用法示例
Jan 31 Python
情人节快乐! python绘制漂亮玫瑰
Aug 18 Python
pandas的连接函数concat()函数的具体使用方法
Jul 09 Python
简单了解python反射机制的一些知识
Jul 13 Python
使用Python和Scribus创建一个RGB立方体的方法
Jul 17 Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 Python
Python3开发环境搭建详细教程
Jun 18 Python
Python深度学习之Pytorch初步使用
May 20 Python
python内置模块之上下文管理contextlib
Jun 14 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
php microtime获取浮点的时间戳
2010/02/21 PHP
PHP上传图片进行等比缩放可增加水印功能
2014/01/13 PHP
PHP使用Face++接口开发微信公众平台人脸识别系统的方法
2015/04/17 PHP
PHP+Mysql+jQuery实现发布微博程序 php篇
2015/10/15 PHP
PHP扩展Memcache分布式部署方案
2015/12/06 PHP
PHP实现的进度条效果详解
2016/05/03 PHP
yii使用bootstrap分页样式的实例
2017/01/17 PHP
php实现映射操作实例详解
2019/10/02 PHP
浅析return false的正确使用
2013/11/04 Javascript
javascript实现可拖动变色并关闭层窗口实例
2015/05/15 Javascript
jquery实现可点击伸缩与展开的菜单效果代码
2015/08/31 Javascript
jQuery实现Tab菜单滚动切换的方法
2015/09/21 Javascript
详解Node.js实现301、302重定向服务
2017/04/07 Javascript
Vue利用路由钩子token过期后跳转到登录页的实例
2017/10/26 Javascript
vue 设置路由的登录权限的方法
2018/07/03 Javascript
JQuery样式操作、click事件以及索引值-选项卡应用示例
2019/05/14 jQuery
利用百度echarts实现图表功能简单入门示例【附源码下载】
2019/06/10 Javascript
python快速排序代码实例
2013/11/21 Python
Pyramid将models.py文件的内容分布到多个文件的方法
2013/11/27 Python
python对象及面向对象技术详解
2016/07/19 Python
如何在sae中设置django,让sae的工作环境跟本地python环境一致
2017/11/21 Python
python使用RNN实现文本分类
2018/05/24 Python
python tornado微信开发入门代码
2018/08/24 Python
利用python如何在前程无忧高效投递简历
2019/05/07 Python
python PIL和CV对 图片的读取,显示,裁剪,保存实现方法
2019/08/07 Python
python matplotlib 绘图 和 dpi对应关系详解
2020/03/14 Python
PyQt5.6+pycharm配置以及pyinstaller生成exe(小白教程)
2020/06/02 Python
东方电视购物:东方CJ
2016/10/12 全球购物
大一自我鉴定范文
2013/12/27 职场文书
经典安踏广告词
2014/03/21 职场文书
投资意向书范本
2014/04/01 职场文书
法人授权委托书
2014/04/03 职场文书
资料员岗位职责
2015/02/10 职场文书
Pytest实现setup和teardown的详细使用详解
2021/04/17 Python
jQuery实现广告显示和隐藏动画
2021/07/04 jQuery
解析python中的jsonpath 提取器
2022/01/18 Python