pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中特殊函数集锦
Jul 27 Python
python subprocess 杀掉全部派生的子进程方法
Jan 16 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 Python
python中使用zip函数出现错误的原因
Sep 28 Python
python批量获取html内body内容的实例
Jan 02 Python
python导入模块交叉引用的方法
Jan 19 Python
Python正则表达式和re库知识点总结
Feb 11 Python
python安装scipy的步骤解析
Sep 28 Python
Python使用Turtle库绘制一棵西兰花
Nov 23 Python
Python collections模块的使用方法
Oct 09 Python
Jupyter Notebook 安装配置与使用详解
Jan 06 Python
Python+Tkinter制作专属图形化界面
Apr 01 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
PHP 数组遍历顺序理解
2009/09/09 PHP
用PHP将数据导入到Foxmail的实现代码
2010/09/05 PHP
thinkphp框架下实现登录、注册、找回密码功能
2016/04/06 PHP
PHP 500报错的快速解决方法
2016/12/14 PHP
一个刚完成的layout(拖动流畅,不受iframe影响)
2007/08/17 Javascript
Javascript select下拉框操作常用方法
2009/11/09 Javascript
js 数据类型转换总结笔记
2011/01/17 Javascript
JS获得URL超链接的参数值实例代码
2013/06/21 Javascript
JS控件ASP.NET的treeview控件全选或者取消(示例代码)
2013/12/16 Javascript
Javascript限制网页只能在微信内置浏览器中访问
2014/11/09 Javascript
基于javascript实现简单的抽奖系统
2020/04/15 Javascript
用JavaScript实现让浏览器停止载入页面的方法
2017/01/19 Javascript
详解JS中的柯里化(currying)
2017/08/17 Javascript
基于react后端渲染模板引擎noox发布使用
2018/01/11 Javascript
通过js动态创建标签,并设置属性方法
2018/02/24 Javascript
微信小程序分享功能之按钮button 边框隐藏和点击隐藏
2018/06/14 Javascript
Vue中$refs的用法详解
2018/06/24 Javascript
webpack打包非模块化js的方法
2018/10/24 Javascript
JavaScript面向对象编程小游戏---贪吃蛇代码实例
2019/05/15 Javascript
JS实现排行榜文字向上滚动轮播效果
2019/11/26 Javascript
Echarts实现单条折线可拖拽效果
2019/12/19 Javascript
vue.js实现照片放大功能
2020/06/23 Javascript
[01:24:16]2018DOTA2亚洲邀请赛 4.6 全明星赛
2018/04/10 DOTA
python版微信跳一跳游戏辅助
2018/01/11 Python
pandas实现选取特定索引的行
2018/04/20 Python
python调用百度语音识别api
2018/08/30 Python
Xadmin+rules实现多选行权限方式(级联效果)
2020/04/07 Python
Python退出时强制运行一段代码的实现方法
2020/04/29 Python
python中upper是做什么用的
2020/07/20 Python
python super()函数的基本使用
2020/09/10 Python
pip/anaconda修改镜像源,加快python模块安装速度的操作
2021/03/04 Python
国际商务系学生个人的自我评价
2013/11/26 职场文书
交通事故和解协议书
2014/09/25 职场文书
解除同居协议书
2015/01/29 职场文书
少先队中队工作总结
2015/08/14 职场文书
go使用Gin框架利用阿里云实现短信验证码功能
2021/08/04 Golang