pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 拷贝对象(深拷贝deepcopy与浅拷贝copy)
Sep 06 Python
如何使用七牛Python SDK写一个同步脚本及使用教程
Aug 23 Python
浅谈Python的异常处理
Jun 19 Python
CentOS6.9 Python环境配置(python2.7、pip、virtualenv)
May 06 Python
python多线程http压力测试脚本
Jun 25 Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 Python
Python 绘制可视化折线图
Jul 22 Python
python实现企业微信定时发送文本消息的示例代码
Nov 24 Python
python爬虫中抓取指数的实例讲解
Dec 01 Python
Jupyter notebook 不自动弹出网页的解决方案
May 21 Python
Python利用folium实现地图可视化
May 23 Python
Python办公自动化解决world文件批量转换
Sep 15 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
PHP常用函数小技巧
2008/09/11 PHP
laravel-admin 管理平台获取当前登陆用户信息的例子
2019/10/08 PHP
jQuery EasyUI API 中文文档 - Pagination分页
2011/09/29 Javascript
动态改变div的z-index属性的简单实例
2013/08/08 Javascript
js的alert弹出框出现乱码解决方案
2013/09/02 Javascript
JQueryiframe页面操作父页面中的元素与方法(实例讲解)
2013/11/19 Javascript
JS实现根据当前文字选择返回被选中的文字
2014/05/21 Javascript
用jquery模仿的a的title属性的例子
2014/10/22 Javascript
jquery动画效果学习笔记(8种效果)
2015/11/13 Javascript
javascript实现的左右无缝滚动效果
2016/09/19 Javascript
Sortable.js拖拽排序使用方法解析
2016/11/04 Javascript
微信小程序 数组中的push与concat的区别
2017/01/05 Javascript
easyui简介_动力节点Java学院整理
2017/07/14 Javascript
在ABP框架中使用BootstrapTable组件的方法
2017/07/31 Javascript
three.js中文文档学习之通过模块导入
2017/11/20 Javascript
MVVM 双向绑定的实现代码
2018/06/21 Javascript
详解关于element el-button使用$attrs的一个注意要点
2018/11/09 Javascript
详解JavaScript的内存空间、赋值和深浅拷贝
2019/04/17 Javascript
详解小程序云开发攻略(解决最棘手的问题)
2019/09/30 Javascript
JavaScript Event Loop相关原理解析
2020/06/10 Javascript
浅谈vue 组件中的setInterval方法和window的不同
2020/07/30 Javascript
[48:32]VGJ.T vs Fnatic 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
Django Admin实现三级联动的示例代码(省市区)
2018/06/22 Python
Python实现一个带权无回置随机抽选函数的方法
2019/07/24 Python
Django 自定义权限管理系统详解(通过中间件认证)
2020/03/11 Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
2020/06/12 Python
Python selenium模块实现定位过程解析
2020/07/09 Python
elf彩妆英国官网:e.l.f. Cosmetics英国(美国平价彩妆品牌)
2017/11/02 全球购物
工程师岗位职责
2013/11/08 职场文书
2014年自我评价
2014/01/04 职场文书
读书演讲主持词
2014/03/18 职场文书
企业优秀员工事迹材料
2014/05/28 职场文书
宣传工作经验材料
2014/06/02 职场文书
颂军魂爱军营演讲稿
2014/09/13 职场文书
2019年思想汇报
2019/06/20 职场文书
js实现上传图片到服务器
2021/04/11 Javascript