pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中property属性实例解析
Feb 10 Python
取numpy数组的某几行某几列方法
Apr 03 Python
Python实现读取字符串按列分配后按行输出示例
Apr 17 Python
Python读取txt某几列绘图的方法
Oct 14 Python
Python爬取商家联系电话以及各种数据的方法
Nov 10 Python
Python matplotlib的使用并自定义colormap的方法
Dec 13 Python
django ModelForm修改显示缩略图 imagefield类型的实例
Jul 28 Python
浅谈matplotlib.pyplot与axes的关系
Mar 06 Python
Python 将代码转换为可执行文件脱离python环境运行(步骤详解)
Jan 25 Python
python中numpy数组与list相互转换实例方法
Jan 29 Python
python spilt()分隔字符串的实现示例
May 21 Python
关于Python OS模块常用文件/目录函数详解
Jul 01 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
PHP的一个完整SMTP类(解决邮件服务器需要验证时的问题)
2006/10/09 PHP
关于页面优化和伪静态
2009/10/11 PHP
PHP生成短网址的3种方法代码实例
2014/07/08 PHP
CentOS下与Apache连接的PHP多版本共存方案实现详解
2015/12/19 PHP
thinkPHP5.0框架环境变量配置方法
2017/03/17 PHP
php集成开发环境详解
2019/09/24 PHP
基于jQuery制作迷你背词汇工具
2010/07/27 Javascript
js中的值类型和引用类型小结 文字说明与实例
2010/12/12 Javascript
浅析js中取绝对值的2种方法
2013/07/09 Javascript
js控制input框只读实现示例
2014/01/20 Javascript
IE6-IE9中tbody的innerHTML不能赋值的解决方法
2014/06/05 Javascript
jquery实现图片上传之前预览的方法
2015/07/11 Javascript
JS DOM实现鼠标滑动图片效果
2020/09/17 Javascript
canvas实现手机端用来上传用户头像的代码
2016/10/20 Javascript
JavaScript中的编码和解码函数
2017/02/15 Javascript
完美解决浏览器跨域的几种方法(汇总)
2017/05/08 Javascript
前端页面文件拖拽上传模块js代码示例
2017/05/19 Javascript
详解Angular 4.x NgTemplateOutlet
2017/05/24 Javascript
使用node打造自己的命令行工具方法教程
2018/03/26 Javascript
JQuery Ajax如何实现注册检测用户名
2020/09/25 jQuery
Java 生成随机字符的示例代码
2021/01/13 Javascript
[56:45]DOTA2上海特级锦标赛D组小组赛#1 EG VS COL第一局
2016/02/28 DOTA
python中类和实例如何绑定属性与方法示例详解
2017/08/18 Python
对python实现模板生成脚本的方法详解
2019/01/30 Python
Python如何合并多个字典或映射
2020/07/24 Python
BeautifulSoup获取指定class样式的div的实现
2020/12/07 Python
北美大型运动类产品商城:Champs Sports
2017/01/12 全球购物
美国在线纱线商店:Darn Good Yarn
2019/03/20 全球购物
员工安全承诺书
2014/05/22 职场文书
社保转移委托书范本
2014/10/08 职场文书
JS实现简单控制视频播放倍速的实例代码
2021/04/18 Javascript
springboot项目以jar包运行的操作方法
2021/06/30 Java/Android
python全面解析接口返回数据
2022/02/12 Python
MySQL优化之慢日志查询
2022/06/10 MySQL
Win10加载疑难解答时出错发生意外错误的解决方法
2022/07/07 数码科技
Windows Server 2016服务器用户管理及远程授权图文教程
2022/08/14 Servers