pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
手动实现把python项目发布为exe可执行程序过程分享
Oct 23 Python
Python读写Excel文件方法介绍
Nov 22 Python
AI人工智能 Python实现人机对话
Nov 13 Python
pandas修改DataFrame列名的方法
Apr 08 Python
python使用opencv实现马赛克效果示例
Sep 28 Python
pytorch中的上采样以及各种反操作,求逆操作详解
Jan 03 Python
python使用Geany编辑器配置方法
Feb 21 Python
简单了解pytest测试框架setup和tearDown
Apr 14 Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 Python
Django nginx配置实现过程详解
Sep 10 Python
Python中的np.argmin()和np.argmax()函数用法
Jun 02 Python
Python学习之异常中的finally使用详解
Mar 16 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
php 应用程序安全防范技术研究
2009/09/25 PHP
php 字符串替换的方法
2012/01/10 PHP
关于查看MSSQL 数据库 用户每个表 占用的空间大小
2013/06/21 PHP
关于Yii中模型场景的一些简单介绍
2019/09/22 PHP
js数字输入框(包括最大值最小值限制和四舍五入)
2009/11/24 Javascript
div浮层,滚动条移动,位置保持不变的4种方法汇总
2013/12/11 Javascript
推荐一个自己用的封装好的javascript插件
2015/01/29 Javascript
js+css实现上下翻页相册代码分享
2015/08/18 Javascript
jquery zTree异步加载简单实例讲解
2016/02/25 Javascript
JavaScript面试开发常用的知识点总结
2016/08/08 Javascript
JQuery实现DIV其他动画效果的简单实例
2016/09/18 Javascript
微信小程序 wxapp画布 canvas详细介绍
2016/10/31 Javascript
js中的this的指向问题详解
2019/08/29 Javascript
从零搭一个自用的前端脚手架的方法步骤
2019/09/23 Javascript
[01:12:35]Spirit vs Navi Supermajor小组赛 A组败者组第一轮 BO3 第二场 6.2
2018/06/03 DOTA
Python实现从订阅源下载图片的方法
2015/03/11 Python
Python中编写ORM框架的入门指引
2015/04/29 Python
python MySQLdb Windows下安装教程及问题解决方法
2015/05/09 Python
Python通过matplotlib绘制动画简单实例
2017/12/13 Python
Python网络爬虫神器PyQuery的基本使用教程
2018/02/03 Python
selenium设置proxy、headers的方法(phantomjs、Chrome、Firefox)
2018/11/29 Python
Python Django实现layui风格+django分页功能的例子
2019/08/29 Python
python利用JMeter测试Tornado的多线程
2020/01/12 Python
Python操作MongoDb数据库流程详解
2020/03/05 Python
基于python实现计算且附带进度条代码实例
2020/03/31 Python
几款好用的python工具库(小结)
2020/10/20 Python
python 将html转换为pdf的几种方法
2020/12/29 Python
耐克美国官网:Nike.com
2016/08/01 全球购物
全球地下的服装和态度:Slam Jam
2018/02/04 全球购物
以工厂直接定价的传奇性能:Ben Hogan Golf
2019/01/04 全球购物
12月红领巾广播稿
2014/02/13 职场文书
优秀毕业生推荐信范文
2014/03/07 职场文书
高中军训感言600字
2014/03/11 职场文书
廉洁自律演讲稿
2014/05/22 职场文书
只用40行Python代码就能写出pdf转word小工具
2021/05/31 Python
Python可视化动图组件ipyvizzu绘制惊艳的可视化动图
2022/04/21 Python