pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Flask框架中使用Flask-SQLAlchemy管理数据库的教程
Jun 14 Python
python自动化脚本安装指定版本python环境详解
Sep 14 Python
Python将多个excel表格合并为一个表格
Feb 22 Python
Python打包方法Pyinstaller的使用
Oct 09 Python
对pyqt5多线程正确的开启姿势详解
Jun 14 Python
python之pexpect实现自动交互的例子
Jul 25 Python
python os.path.isfile 的使用误区详解
Nov 29 Python
python 字段拆分详解
Dec 17 Python
python如何通过twisted搭建socket服务
Feb 03 Python
python等差数列求和公式前 100 项的和实例
Feb 25 Python
如何用tempfile库创建python进程中的临时文件
Jan 28 Python
关于Python中进度条的六个实用技巧分享
Apr 05 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
PHP中查询SQL Server或Sybase时TEXT字段被截断的解决方法
2009/03/10 PHP
Zend Studio (eclipse)使用速度优化方法
2011/03/23 PHP
zend framework配置操作数据库实例分析
2012/12/06 PHP
解析php多线程下载远程多个文件
2013/06/25 PHP
PHP设置图片文件上传大小的具体实现方法
2013/10/11 PHP
JS分割字符串并放入数组的函数
2011/07/04 Javascript
js动画(animate)简单引擎代码示例
2012/12/04 Javascript
javascript实现了照片拖拽点击置顶的照片墙代码
2015/04/03 Javascript
jQuery实现不断闪烁文字的方法
2015/05/15 Javascript
jquery实现漫天雪花飞舞的圣诞祝福雪花效果代码分享
2015/08/20 Javascript
iframe中子父类窗口调用JS的方法及注意事项
2015/08/25 Javascript
jquery实现超简洁的TAB选项卡效果代码
2015/08/28 Javascript
Bootstrap轮播插件使用代码
2016/10/11 Javascript
详解nodejs 文本操作模块-fs模块(三)
2016/12/22 NodeJs
javascript实现多张图片左右无缝滚动效果
2017/03/22 Javascript
JS实现标签页切换效果
2017/05/04 Javascript
vue.js如何将echarts封装为组件一键使用详解
2017/10/10 Javascript
在vue使用clipboard.js进行一键复制文本的实现示例
2019/01/15 Javascript
浅谈Vue3.0之前你必须知道的TypeScript实战技巧
2019/09/11 Javascript
vue项目中使用bpmn为节点添加颜色的方法
2020/04/30 Javascript
详解Vue之事件处理
2020/07/10 Javascript
使用Python下载Bing图片(代码)
2013/11/07 Python
高效测试用例组织算法pairwise之Python实现方法
2017/07/19 Python
python 中的list和array的不同之处及转换问题
2018/03/13 Python
Python使用scipy模块实现一维卷积运算示例
2019/09/05 Python
pymysql的简单封装代码实例
2020/01/08 Python
python 穷举指定长度的密码例子
2020/04/02 Python
python爬虫用request库处理cookie的实例讲解
2021/02/20 Python
CSS3打造磨砂玻璃背景效果
2016/09/28 HTML / CSS
西海岸男士和男童服装:Johnnie-O
2018/03/15 全球购物
《东方明珠》教学反思
2014/04/20 职场文书
求职信格式要求
2014/05/23 职场文书
五四青年节优秀演讲稿范文
2014/05/28 职场文书
新兵入伍心得体会
2014/09/04 职场文书
股东授权委托书范本
2014/09/13 职场文书
Java使用JMeter进行高并发测试
2021/11/23 Java/Android