pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈python新手中常见的疑惑及解答
Jun 14 Python
Python IDLE 错误:IDLE''s subprocess didn''t make connection 的解决方案
Feb 13 Python
python基础之包的导入和__init__.py的介绍
Jan 08 Python
对Python 两大环境管理神器 pyenv 和 virtualenv详解
Dec 31 Python
Django中ajax发送post请求 报403错误CSRF验证失败解决方案
Aug 13 Python
python中有函数重载吗
May 28 Python
python如何处理程序无法打开
Jun 16 Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 Python
Django日志及中间件模块应用案例
Sep 10 Python
python 批量将中文名转换为拼音
Feb 07 Python
如何在Python中创建二叉树
Mar 30 Python
python字符串的多行输出的实例详解
Jun 08 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
在PHP中检查PHP文件是否有语法错误的方法
2009/12/23 PHP
jQuery Mobile + PHP实现文件上传
2014/12/12 PHP
浅谈Yii乐观锁的使用及原理
2017/07/25 PHP
PHP的RSA加密解密方法以及开发接口使用
2018/02/11 PHP
PHP远程连接oracle数据库操作实现方法图文详解
2019/04/11 PHP
js bind 函数 使用闭包保存执行上下文
2011/12/26 Javascript
jQuery.each()用法分享
2012/07/31 Javascript
js日期相关函数总结分享
2013/10/15 Javascript
jquery实现在网页指定区域显示自定义右键菜单效果
2015/08/25 Javascript
Angular Js文件上传之form-data
2015/08/28 Javascript
jQuery简单实现仿京东商城的左侧菜单效果代码
2015/09/09 Javascript
原生JS封装ajax 传json,str,excel文件上传提交表单(推荐)
2016/06/21 Javascript
如何制作幻灯片(代码分享)
2017/01/06 Javascript
JavaScript中利用for循环遍历数组
2017/01/15 Javascript
微信JSAPI Ticket接口签名详解
2020/06/28 Javascript
微信小程序实现流程进度的图样式功能
2018/01/16 Javascript
JS面试题大坑之隐式类型转换实例代码
2018/10/14 Javascript
使用原生JS实现火锅点餐小程序(面向对象思想)
2019/12/10 Javascript
Nuxt pages下不同的页面对应layout下的页面布局操作
2020/11/05 Javascript
[01:13:51]TNC vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
Python基于PycURL实现POST的方法
2015/07/25 Python
Pycharm技巧之代码跳转该如何回退
2017/07/16 Python
python中类的输出或类的实例输出为这种形式的原因
2019/08/12 Python
Python解压 rar、zip、tar文件的方法
2019/11/19 Python
python3.x中安装web.py步骤方法
2020/06/23 Python
Django正则URL匹配实现流程解析
2020/11/13 Python
css3新增颜色表示方式分享
2014/04/15 HTML / CSS
求职简历自我评价范例
2014/03/12 职场文书
大学生求职计划书
2014/04/30 职场文书
我的中国梦演讲稿400字
2014/08/19 职场文书
刑事代理授权委托书
2014/09/17 职场文书
2014银行授权委托书样本
2014/10/04 职场文书
环保证明
2015/06/23 职场文书
致运动员的广播稿
2015/08/19 职场文书
ROS系统将python包编译为可执行文件的简单步骤
2021/07/25 Python
Oracle 触发器trigger使用案例
2022/02/24 Oracle