pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python监控网卡流量并使用graphite绘图的示例
Apr 27 Python
Python内置函数之filter map reduce介绍
Nov 30 Python
Python3.4编程实现简单抓取爬虫功能示例
Sep 14 Python
python批量导入数据进Elasticsearch的实例
May 30 Python
详解Python字符串切片
May 20 Python
python实现视频分帧效果
May 31 Python
python制作简单五子棋游戏
Jun 18 Python
Python autoescape标签用法解析
Jan 17 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 Python
如何使用scrapy中的ItemLoader提取数据
Sep 30 Python
python 基于pygame实现俄罗斯方块
Mar 02 Python
Pandas 数据编码的十种方法
Apr 20 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
PHP4实际应用经验篇(9)
2006/10/09 PHP
PHP输入流php://input介绍
2012/09/18 PHP
利用phpExcel实现Excel数据的导入导出(全步骤详细解析)
2013/11/26 PHP
php如何实现只替换一次或N次
2015/10/29 PHP
Yii2使用$this->context获取当前的Module、Controller(控制器)、Action等
2017/03/29 PHP
javascript mouseover、mouseout停止事件冒泡的解决方案
2009/04/07 Javascript
让firefox支持IE的一些方法的javascript扩展函数代码
2010/01/02 Javascript
网站基于flash实现的Banner图切换效果代码
2014/10/14 Javascript
jQuery插件ajaxFileUpload实现异步上传文件效果
2015/04/14 Javascript
Jquery数字上下滚动动态切换插件
2015/08/08 Javascript
jQuery多级手风琴菜单实例讲解
2015/10/22 Javascript
Bootstrap前端开发案例一
2016/06/17 Javascript
获取JavaScript异步函数的返回值
2016/12/21 Javascript
微信小程序-小说阅读小程序实例(demo)
2017/01/12 Javascript
nodejs个人博客开发第七步 后台登陆
2017/04/12 NodeJs
JavaScript惰性载入函数实例分析
2019/03/27 Javascript
浅谈ECMAScript 中的Array类型
2019/06/10 Javascript
JQuery发送ajax请求时中文乱码问题解决
2019/11/14 jQuery
jQuery操作元素的内容和样式完整实例分析
2020/01/10 jQuery
在vs code 中如何创建一个自己的 Vue 模板代码
2020/11/10 Javascript
[01:20:05]DOTA2-DPC中国联赛 正赛 Ehome vs VG BO3 第二场 2月5日
2021/03/11 DOTA
跟老齐学Python之总结参数的传递
2014/10/10 Python
Python TestCase中的断言方法介绍
2019/05/02 Python
python3实现小球转动抽奖小游戏
2020/04/15 Python
基于python和flask实现http接口过程解析
2020/06/15 Python
Python中openpyxl实现vlookup函数的实例
2020/10/28 Python
用Python 执行cmd命令
2020/12/18 Python
h5页面背景图很长要有滚动条滑动效果的实现
2021/01/27 HTML / CSS
如果重写了对象的equals()方法,需要考虑什么
2014/11/02 面试题
房地产出纳岗位职责
2013/12/01 职场文书
文明学生事迹材料
2014/01/29 职场文书
优秀的2014年两会精神解读
2014/03/17 职场文书
冬季安全检查方案
2014/05/23 职场文书
机关干部三严三实心得体会
2014/10/13 职场文书
2016暑期政治学习心得体会
2016/01/23 职场文书
2016年乡镇七一建党节活动总结
2016/04/05 职场文书