pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python人人网登录应用实例
Sep 26 Python
python操作ie登陆土豆网的方法
May 09 Python
VTK与Python实现机械臂三维模型可视化详解
Dec 13 Python
python放大图片和画方格实现算法
Mar 30 Python
浅谈pytorch和Numpy的区别以及相互转换方法
Jul 26 Python
pycharm远程开发项目的实现步骤
Jan 20 Python
简单了解django索引的相关知识
Jul 17 Python
Python笔记之代理模式
Nov 20 Python
python和pywin32实现窗口查找、遍历和点击的示例代码
Apr 01 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 Python
python 实现控制鼠标键盘
Nov 27 Python
pandas 数据类型转换的实现
Dec 29 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
几个php应用技巧
2008/03/27 PHP
基于initPHP的框架介绍
2013/04/18 PHP
ThinkPHP模板循环输出Volist标签用法实例详解
2016/03/23 PHP
PHP实现从PostgreSQL数据库检索数据分页显示及根据条件查找数据示例
2018/06/09 PHP
从零开始学习jQuery (四) jQuery中操作元素的属性与样式
2011/02/23 Javascript
js有序数组的连接问题
2013/10/01 Javascript
封装的jquery翻页滚动(示例代码)
2013/11/18 Javascript
node.js中的url.parse方法使用说明
2014/12/10 Javascript
JQuery遍历DOM节点的方法
2015/06/11 Javascript
jQuery实现带有动画效果的回到顶部和底部代码
2015/11/04 Javascript
BootStrap组件之进度条的基本用法
2017/01/19 Javascript
bootstrap轮播图示例代码分享
2017/05/17 Javascript
EasyUI实现下拉框多选功能
2017/11/07 Javascript
关于ES6箭头函数中的this问题
2018/02/27 Javascript
Angular开发实践之服务端渲染
2018/03/29 Javascript
深入理解Promise.all
2018/08/08 Javascript
jQuery实现图片简单轮播功能示例
2018/08/13 jQuery
element-ui中Table表格省市区合并单元格的方法实现
2019/08/07 Javascript
vue中element 的upload组件发送请求给后端操作
2020/09/07 Javascript
零基础写python爬虫之使用urllib2组件抓取网页内容
2014/11/04 Python
使用rpclib进行Python网络编程时的注释问题
2015/05/06 Python
python中hashlib模块用法示例
2017/10/30 Python
NumPy.npy与pandas DataFrame的实例讲解
2018/07/09 Python
python 多线程串行和并行的实例
2019/02/22 Python
python实现连连看辅助之图像识别延伸
2019/07/17 Python
Matplotlib绘制雷达图和三维图的示例代码
2020/01/07 Python
详解Python 实现 ZeroMQ 的三种基本工作模式
2020/03/24 Python
如何用Anaconda搭建虚拟环境并创建Django项目
2020/08/02 Python
迪斯尼商品官方网站:ShopDisney
2016/08/01 全球购物
中国跨境在线时尚零售商:Bellelily
2018/04/06 全球购物
小学教师师德感言
2014/02/10 职场文书
入股协议书
2014/04/14 职场文书
法英专业大学生职业生涯规划书范文
2014/09/22 职场文书
2015年学校安全管理工作总结
2015/05/11 职场文书
电影小兵张嘎观后感
2015/06/03 职场文书
地心历险记观后感
2015/06/15 职场文书