pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Django的上下文中设置变量的方法
Jul 20 Python
python常见排序算法基础教程
Apr 13 Python
python 简单的绘图工具turtle使用详解
Jun 21 Python
Python编写一个优美的下载器
Apr 15 Python
删除python pandas.DataFrame 的多重index实例
Jun 08 Python
对Tensorflow中权值和feature map的可视化详解
Jun 14 Python
tensorflow实现逻辑回归模型
Sep 08 Python
解决pip install xxx报错SyntaxError: invalid syntax的问题
Nov 30 Python
在Django下测试与调试REST API的方法详解
Aug 29 Python
解析PyCharm Python运行权限问题
Jan 08 Python
对python中return与yield的区别详解
Mar 12 Python
opencv实现图像平移效果
Mar 24 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
PHP无限分类的类
2007/01/02 PHP
更正确的asp冒泡排序
2007/05/24 Javascript
setTimeout与setInterval在不同浏览器下的差异
2010/01/24 Javascript
利用JS延迟加载百度分享代码,提高网页速度
2013/07/01 Javascript
jQuery取得select选择的文本与值的示例
2013/12/09 Javascript
浅析JQuery UI Dialog的样式设置问题
2013/12/18 Javascript
原生JavaScript生成GUID的实现示例
2014/09/05 Javascript
jquery京东商城双11焦点图多图广告特效代码分享
2015/09/06 Javascript
js检测iframe是否加载完成的方法
2015/11/26 Javascript
AngularJS $injector 依赖注入详解
2016/09/14 Javascript
AngularJS创建自定义指令的方法详解
2016/11/03 Javascript
three.js快速入门【推荐】
2017/01/21 Javascript
jquery封装插件时匿名函数形参和实参的写法解释
2017/02/14 Javascript
微信小程序页面跳转功能之从列表的item项跳转到下一个页面的方法
2017/11/27 Javascript
浅析JS抽象工厂模式
2017/12/14 Javascript
vue+vue-router转场动画的实例代码
2018/09/01 Javascript
小试SVG之新手小白入门教程
2019/01/08 Javascript
Vue 实现简易多行滚动"弹幕"效果
2020/01/02 Javascript
python文件的md5加密方法
2016/04/06 Python
Python的IDEL增加清屏功能实例
2017/06/19 Python
Python使用Selenium+BeautifulSoup爬取淘宝搜索页
2018/02/24 Python
Python正则表达式实现简易计算器功能示例
2019/05/07 Python
python:按行读入,排序然后输出的方法
2019/07/20 Python
python标准库sys和OS的函数使用方法与实例详解
2020/02/12 Python
python matplotlib包图像配色方案分享
2020/03/14 Python
keras打印loss对权重的导数方式
2020/06/10 Python
如何基于Python爬虫爬取美团酒店信息
2020/11/03 Python
Django用户认证系统如何实现自定义
2020/11/12 Python
用python爬虫批量下载pdf的实现
2020/12/01 Python
海淘零差价,宝贝全球购: 宝贝格子
2016/08/24 全球购物
90后毕业生的求职信范文
2013/09/21 职场文书
商务会议邀请函
2014/01/09 职场文书
学生自我评语大全
2014/04/18 职场文书
教师节横幅标语
2014/10/08 职场文书
2015年电气技术员工作总结
2015/07/24 职场文书
Redis如何一键部署脚本
2021/04/12 Redis