pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
尝试使用Python多线程抓取代理服务器IP地址的示例
Nov 09 Python
由浅入深讲解python中的yield与generator
Apr 05 Python
python版本坑:md5例子(python2与python3中md5区别)
Jun 20 Python
python好玩的项目—色情图片识别代码分享
Nov 07 Python
python调用c++ ctype list传数组或者返回数组的方法
Feb 13 Python
解决win7操作系统Python3.7.1安装后启动提示缺少.dll文件问题
Jul 15 Python
python打印直角三角形与等腰三角形实例代码
Oct 20 Python
基于python3 的百度图片下载器的实现代码
Nov 05 Python
TensorFlow tensor的拼接实例
Jan 19 Python
浅谈Pycharm的项目文件名是红色的原因及解决方式
Jun 01 Python
python 线程的五个状态
Sep 22 Python
Python爬虫实战之爬取京东商品数据并实实现数据可视化
Jun 07 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
239军机修复记
2021/03/02 无线电
php读取大文件示例分享(文件操作类)
2014/04/13 PHP
8个PHP程序员常用的功能汇总
2014/12/18 PHP
PHP将MySQL的查询结果转换为数组并用where拼接的示例
2016/05/13 PHP
PHP实现四种基础排序算法的运行时间比较(推荐)
2016/08/11 PHP
深入理解JavaScript 函数
2016/06/06 Javascript
BootStrap table删除指定行的注意事项(笔记整理)
2017/02/05 Javascript
详谈js遍历集合(Array,Map,Set)
2017/04/06 Javascript
Bootstrap.css与layDate日期选择样式起冲突的解决办法
2017/04/07 Javascript
mui上拉加载功能实例详解
2017/04/13 Javascript
关于express与koa的使用对比详解
2018/01/25 Javascript
vue实现在表格里,取每行的id的方法
2018/03/09 Javascript
一个基于react的图片裁剪组件示例
2018/04/18 Javascript
Vue自定义指令封装节流函数的方法示例
2018/07/09 Javascript
[01:03:00]DOTA2上海特级锦标赛A组败者赛 EHOME VS CDEC第一局
2016/02/25 DOTA
Python中的迭代器漫谈
2015/02/03 Python
Python的randrange()方法使用教程
2015/05/15 Python
基于Django contrib Comments 评论模块(详解)
2017/12/08 Python
Python处理文本换行符实例代码
2018/02/03 Python
Django中针对基于类的视图添加csrf_exempt实例代码
2018/02/11 Python
python3 判断列表是一个空列表的方法
2018/05/04 Python
查看TensorFlow checkpoint文件中的变量名和对应值方法
2018/06/14 Python
Python图像处理之图片文字识别功能(OCR)
2019/07/30 Python
python读取mysql数据绘制条形图
2020/03/25 Python
canvas绘制表情包的示例代码
2018/07/09 HTML / CSS
阿巴庭院:Abba Patio
2019/06/18 全球购物
CAT鞋加拿大官网:CAT Footwear加拿大
2020/08/05 全球购物
几个MySql的面试题
2013/04/22 面试题
深入开展党的群众路线教育实践活动方案
2014/02/04 职场文书
供货协议书范本
2014/04/22 职场文书
加强作风建设演讲稿
2014/10/24 职场文书
幼儿园教师自我评价
2015/03/04 职场文书
不同意离婚上诉状
2015/05/23 职场文书
雷锋的观后感
2015/06/10 职场文书
浅谈Redis存储数据类型及存取值方法
2021/05/08 Redis
python分分钟绘制精美地图海报
2022/02/15 Python