pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现多行注释的另类方法
Aug 22 Python
全面解析Python的While循环语句的使用方法
Oct 13 Python
python编写分类决策树的代码
Dec 21 Python
python实现雨滴下落到地面效果
Jun 21 Python
mac安装pytorch及系统的numpy更新方法
Jul 26 Python
PyQt4实时显示文本内容GUI的示例
Jun 14 Python
PowerBI和Python关于数据分析的对比
Jul 11 Python
使用PyCharm进行远程开发和调试的实现
Nov 04 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 Python
Python类和实例的属性机制原理详解
Mar 21 Python
Python实现中英文全文搜索的示例
Dec 04 Python
python实现双向链表原理
May 25 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
php注入实例
2006/10/09 PHP
PHP4与PHP3中一个不兼容问题的解决方法
2006/10/09 PHP
浅析PHP水印技术
2007/02/14 PHP
php+ajax做仿百度搜索下拉自动提示框(有实例)
2012/08/21 PHP
深入PHP操作MongoDB的技术总结
2013/06/02 PHP
学习php设计模式 php实现享元模式(flyweight)
2015/12/07 PHP
Yii 2.0中场景的使用教程
2017/06/02 PHP
JavaScript OOP类与继承
2009/11/15 Javascript
javascript基础知识大全 便于大家学习,也便于我自己查看
2012/08/17 Javascript
jquery及原生js获取select下拉框选中的值示例
2013/10/25 Javascript
javascript删除字符串最后一个字符
2014/01/14 Javascript
node.js中的fs.mkdirSync方法使用说明
2014/12/17 Javascript
全面解析Bootstrap中Carousel轮播的使用方法
2016/06/13 Javascript
Bootstrap CSS组件之按钮组(btn-group)
2016/12/17 Javascript
Javascript中的prototype与继承
2017/02/06 Javascript
jquery.validate表单验证插件使用详解
2017/06/21 jQuery
bootstrap3-dialog-master模态框使用详解
2017/08/22 Javascript
Angular入口组件(entry component)与声明式组件的区别详解
2018/04/09 Javascript
Vue二次封装axios为插件使用详解
2018/05/21 Javascript
vue动态改变背景图片demo分享
2018/09/13 Javascript
详解express使用vue-router的history踩坑
2019/06/05 Javascript
javascript实现的字符串转换成数组操作示例
2019/06/13 Javascript
JavaScript实现多文件下载方法解析
2020/08/07 Javascript
python 排列组合之itertools
2013/03/20 Python
python K近邻算法的kd树实现
2018/09/06 Python
使用python实现简单五子棋游戏
2019/06/18 Python
python fuzzywuzzy模块模糊字符串匹配详细用法
2019/08/29 Python
如何基于Python实现数字类型转换
2020/02/07 Python
python自动点赞功能的实现思路
2020/02/26 Python
python将音频进行变速的操作方法
2020/04/08 Python
西班牙网上书店:Casa del Libro
2016/11/01 全球购物
向全球直邮输送天然健康产品:iHerb.com
2020/05/03 全球购物
关于爱国的标语
2014/06/24 职场文书
课外活动总结范文
2014/07/09 职场文书
MySQL 视图(View)原理解析
2021/05/19 MySQL
pytorch中的torch.nn.Conv2d()函数图文详解
2022/02/28 Python