pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python 制作图片转pdf工具
Jan 30 Python
Python列表推导式、字典推导式与集合推导式用法实例分析
Feb 07 Python
python实现简单淘宝秒杀功能
May 03 Python
python生成器与迭代器详解
Jan 01 Python
详解python中的数据类型和控制流
Aug 08 Python
Python中IP地址处理IPy模块的方法
Aug 16 Python
画pytorch模型图,以及参数计算的方法
Aug 17 Python
wxPython实现文本框基础组件
Nov 18 Python
YUV转为jpg图像的实现
Dec 09 Python
Python过滤序列元素的方法
Jul 31 Python
属性与 @property 方法让你的python更高效
Sep 21 Python
详解Python中的for循环
Apr 30 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
日本收入最高的漫画家:海贼王作者版税年收入高达8.45亿元
2020/03/04 日漫
php下MYSQL limit的优化
2008/01/10 PHP
zend optimizer在wamp的基础上安装图文教程
2013/10/26 PHP
php实现微信公众平台发红包功能
2018/06/14 PHP
Linux下安装Memcached服务器和客户端与PHP使用示例
2019/04/15 PHP
window.event快达到全浏览器支持了,以后使用就方便了
2011/11/30 Javascript
一个网页标题title的闪动提示效果实现思路
2014/03/22 Javascript
总结在前端排序中遇到的问题
2016/07/19 Javascript
JS作用域闭包、预解释和this关键字综合实例解析
2016/12/16 Javascript
JavaScript利用闭包实现模块化
2017/01/13 Javascript
Nodejs中使用captchapng模块生成图片验证码
2017/05/18 NodeJs
vue中$set的使用(结合在实际应用中遇到的坑)
2018/07/10 Javascript
vue router带参数页面刷新或回退参数消失的解决方法
2019/02/27 Javascript
vue动态注册组件实例代码详解
2019/05/30 Javascript
一文快速了解JQuery中的AJAX
2019/05/31 jQuery
JavaScript获取某一天所在的星期
2019/09/05 Javascript
a标签调用js的方法总结
2019/09/05 Javascript
node.js使用mongoose操作数据库实现购物车的增、删、改、查功能示例
2019/12/23 Javascript
python删除文件示例分享
2014/01/28 Python
python 简单搭建阻塞式单进程,多进程,多线程服务的实例
2017/11/01 Python
SVM基本概念及Python实现代码
2017/12/27 Python
详解Python3中ceil()函数用法
2019/02/19 Python
Python图像处理库PIL的ImageFilter模块使用介绍
2020/02/26 Python
PyQt5.6+pycharm配置以及pyinstaller生成exe(小白教程)
2020/06/02 Python
Python正则表达式高级使用方法汇总
2020/06/18 Python
利用HTML5中的Canvas绘制一张笑脸的教程
2015/05/07 HTML / CSS
英国高街电视:High Street TV
2018/05/22 全球购物
回馈慈善的设计师太阳镜:DIFF eyewear
2019/10/17 全球购物
机械绘图员岗位职责
2013/11/19 职场文书
自我评价的范文
2014/02/02 职场文书
内刊编辑求职自荐书范文
2014/02/19 职场文书
会计岗位职责范本
2014/03/07 职场文书
个人承诺书怎么写
2014/05/24 职场文书
税务会计岗位职责
2015/04/02 职场文书
离婚承诺书格式范文
2015/05/04 职场文书
Windows Server 2019 域控制器安装图文教程
2022/04/28 Servers