pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
17个Python小技巧分享
Jan 23 Python
Python中操作MySQL入门实例
Feb 08 Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 Python
Python中Threading用法详解
Dec 27 Python
人生苦短我用python python如何快速入门?
Mar 12 Python
python直接获取API传递回来的参数方法
Dec 17 Python
Python处理时间日期坐标轴过程详解
Jun 25 Python
基于Django ORM、一对一、一对多、多对多的全面讲解
Jul 26 Python
将Python文件打包成.EXE可执行文件的方法
Aug 11 Python
python如何求数组连续最大和的示例代码
Feb 04 Python
python 线性回归分析模型检验标准--拟合优度详解
Feb 24 Python
Python读写操作csv和excle文件代码实例
Mar 16 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
PHP4实际应用经验篇(2)
2006/10/09 PHP
php不使用插件导出excel的简单方法
2014/03/04 PHP
destoon实现首页显示供应、企业、资讯条数的方法
2014/07/15 PHP
ThinkPHP入库出现两次反斜线转义及数据库类转义的解决方法
2014/11/04 PHP
php无序树实现方法
2015/07/28 PHP
ThinkPHP中调用PHPExcel的实现代码
2017/04/08 PHP
PHP7 其他修改
2021/03/09 PHP
jquery入门—编写一个导航条(可伸缩)
2013/01/07 Javascript
HTML5 Shiv完美解决IE(IE6/IE7/IE8)不兼容HTML5标签的方法
2015/11/25 Javascript
实例详解JavaScript获取链接参数的方法
2016/01/01 Javascript
JavaScript、C# URL编码、解码总结
2017/01/21 Javascript
Vue2几种常见开局方式详解
2017/09/09 Javascript
《javascript少儿编程》location术语总结
2018/05/27 Javascript
在Vue中使用axios请求拦截的实现方法
2018/10/25 Javascript
Node.js JSON模块用法实例分析
2019/01/04 Javascript
Vue 组件修改根实例的数据的方法
2019/04/02 Javascript
Vue使用mixin分发组件的可复用功能
2019/09/01 Javascript
浅谈layui 绑定form submit提交表单的注意事项
2019/10/25 Javascript
ajax jquery实现页面某一个div的刷新效果
2021/03/04 jQuery
Python的词法分析与语法分析
2013/05/18 Python
Python中使用HTMLParser解析html实例
2015/02/08 Python
python字典DICT类型合并详解
2017/08/17 Python
Python中super函数的用法
2017/11/17 Python
python实现抖音点赞功能
2019/04/07 Python
用Python配平化学方程式的方法
2019/07/20 Python
python3.7 sys模块的具体使用
2019/07/22 Python
Flask框架学习笔记之使用Flask实现表单开发详解
2019/08/12 Python
python实现自动化报表功能(Oracle/plsql/Excel/多线程)
2019/12/02 Python
Python OpenCV实现测量图片物体宽度
2020/05/27 Python
请写出一段Python代码实现删除一个list里面的重复元素
2015/12/29 面试题
幼儿园教师备课制度
2014/01/12 职场文书
24岁生日感言
2014/01/13 职场文书
招标授权委托书样本
2014/09/23 职场文书
2015年健康教育工作总结
2015/04/10 职场文书
请病假条范文
2015/08/17 职场文书
Pillow图像处理库安装及使用
2022/04/12 Python