pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python函数返回多个值的示例方法
Dec 04 Python
仅用50行Python代码实现一个简单的代理服务器
Apr 08 Python
Python中使用PyQt把网页转换成PDF操作代码实例
Apr 23 Python
Python将字符串常量转化为变量方法总结
Mar 17 Python
PyQt5 实现字体大小自适应分辨率的方法
Jun 18 Python
Python使用lambda表达式对字典排序操作示例
Jul 25 Python
Django CBV与FBV原理及实例详解
Aug 12 Python
tensorflow常用函数API介绍
Apr 19 Python
python中round函数如何使用
Jun 19 Python
Python dict的常用方法示例代码
Jun 23 Python
Django数据库迁移常见使用方法
Nov 12 Python
浅谈盘点5种基于Python生成的个性化语音方法
Feb 05 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
php中截取中文字符串的代码小结
2011/07/17 PHP
Laravel 5.5 的自定义验证对象/类示例代码详解
2017/08/29 PHP
PHP基于session.upload_progress 实现文件上传进度显示功能详解
2019/08/09 PHP
解决Laravel 不能创建 migration 的问题
2019/10/09 PHP
PHP 与 js的通信(via ajax,json)
2010/11/16 Javascript
javascript控制在光标位置插入文字适合表情的插入
2014/06/09 Javascript
BootStrap select2 动态改变值的方法
2017/02/10 Javascript
javascript 玩转Date对象(实例讲解)
2017/07/11 Javascript
React-Native左右联动List的示例代码
2017/09/21 Javascript
React中的refs的使用教程
2018/02/13 Javascript
详解angular应用容器化部署
2018/08/14 Javascript
Vue中多元素过渡特效的解决方案
2020/02/05 Javascript
微信小程序使用 vant Dialog组件的正确方式
2020/02/21 Javascript
基于Element的组件改造的树形选择器(树形下拉框)
2020/02/27 Javascript
vue二选一tab栏切换新做法实现
2021/01/19 Vue.js
nodejs中使用worker_threads来创建新的线程的方法
2021/01/22 NodeJs
[01:03:38]2014 DOTA2国际邀请赛中国区预选赛5.21 CNB VS CIS
2014/05/22 DOTA
[00:21]DOTA2亚洲邀请赛 Logo演绎
2015/02/07 DOTA
[01:32]2016国际邀请赛中国区预选赛CDEC战队教练采访
2016/06/26 DOTA
Python批量重命名同一文件夹下文件的方法
2015/05/25 Python
python之文件的读写和文件目录以及文件夹的操作实现代码
2016/08/28 Python
Python操作json的方法实例分析
2018/12/06 Python
Python----数据预处理代码实例
2019/03/20 Python
Python内存管理实例分析
2019/07/10 Python
django Admin文档生成器使用详解
2019/07/22 Python
Python提取视频中图片的示例(按帧、按秒)
2020/10/22 Python
KIKO MILANO英国官网:意大利知名化妆品和护肤品品牌
2017/09/25 全球购物
耐克波兰官方网站:Nike波兰
2019/09/03 全球购物
回门宴父母答谢词
2014/01/26 职场文书
2014自主招生自荐信策略
2014/01/27 职场文书
教学改革实施方案
2014/03/31 职场文书
会计专业自荐信
2014/06/03 职场文书
信息管理与信息系统专业求职信
2014/06/21 职场文书
职代会闭幕词
2015/01/28 职场文书
幼儿园中秋节活动总结
2015/03/23 职场文书
apache基于端口创建虚拟主机的示例
2021/04/22 Servers