pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python按行读取文件的简单实现方法
Jun 22 Python
Python使用ntplib库同步校准当地时间的方法
Jul 02 Python
Python如何实现文本转语音
Aug 08 Python
python爬虫使用cookie登录详解
Dec 27 Python
python实现扫描日志关键字的示例
Apr 28 Python
如何用python整理附件
May 13 Python
pandas分别写入excel的不同sheet方法
Dec 11 Python
Django ORM 查询表中某列字段值的方法
Apr 30 Python
Python常用数据分析模块原理解析
Jul 20 Python
python批量合成bilibili的m4s缓存文件为MP4格式 ver2.5
Dec 01 Python
python中pyqtgraph知识点总结
Jan 26 Python
python树莓派通过队列实现进程交互的程序分析
Jul 04 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
PHP操作XML作为数据库的类
2010/12/19 PHP
PHP关联数组实现根据元素值删除元素的方法
2015/06/26 PHP
PHP实现登陆表单提交CSRF及验证码
2017/01/24 PHP
ThinkPHP3.2框架操作Redis的方法分析
2019/05/05 PHP
extjs 为某个事件设置拦截器
2010/01/15 Javascript
js获取GridView中行数据的两种方法 分享
2013/07/13 Javascript
jquery索引在使用中的一些困惑
2013/10/24 Javascript
jQuery的缓存机制浅析
2014/06/07 Javascript
javascript实现单击和双击并存的方法
2014/12/13 Javascript
JavaScript函数内部属性和函数方法实例详解
2016/03/17 Javascript
JS验证图片格式和大小并预览的简单实例
2016/10/11 Javascript
JS+HTML5 FileReader对象用法示例
2017/04/07 Javascript
JS组件系列之MVVM组件 vue 30分钟搞定前端增删改查
2017/04/28 Javascript
解决JS外部文件中文注释出现乱码问题
2017/07/09 Javascript
js图片放大镜实例讲解(必看篇)
2017/07/17 Javascript
Vue如何从1.0迁移到2.0
2017/10/19 Javascript
使用Electron构建React+Webpack桌面应用的方法
2017/12/15 Javascript
vue .sync修饰符的使用详解
2018/06/15 Javascript
实例分析vue循环列表动态数据的处理方法
2018/09/28 Javascript
JavaScript实现PC端四格密码输入框功能
2020/02/19 Javascript
Python读取MRI并显示为灰度图像实例代码
2018/01/03 Python
利用ImageAI库只需几行python代码实现目标检测
2019/08/09 Python
python自动循环定时开关机(非重启)测试
2019/08/26 Python
Python爬虫爬取Bilibili弹幕过程解析
2019/10/10 Python
详解python opencv、scikit-image和PIL图像处理库比较
2019/12/26 Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
2020/01/02 Python
Pandas时间序列:时期(period)及其算术运算详解
2020/02/25 Python
Html5实现二维码扫描并解析
2016/01/20 HTML / CSS
YOOX台湾:意大利奢侈品电商
2018/10/13 全球购物
薪酬专员岗位职责
2014/02/18 职场文书
马云北大演讲完整版:真心话,什么才是阿里的核心竞争力?
2014/04/04 职场文书
医院深入开展党的群众路线教育实践活动实施方案
2014/08/27 职场文书
《水浒传》读后感3篇(范文)
2019/09/19 职场文书
关于CSS自定义属性与前端页面的主题切换问题
2022/03/21 HTML / CSS
sql server偶发出现死锁的解决方法
2022/04/10 SQL Server
MySQL数据库查询之多表查询总结
2022/08/05 MySQL