pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python面向对象编程中的类和对象学习教程
Mar 30 Python
Python程序中的观察者模式结构编写示例
May 27 Python
对tensorflow 的模型保存和调用实例讲解
Jul 28 Python
django框架模板中定义变量(set variable in django template)的方法分析
Jun 24 Python
在python中,使用scatter绘制散点图的实例
Jul 03 Python
Python 日志logging模块用法简单示例
Oct 18 Python
Python操作Excel工作簿的示例代码(\*.xlsx)
Mar 23 Python
python实现PDF中表格转化为Excel的方法
Jun 16 Python
Python设计密码强度校验程序
Jul 30 Python
Django中ORM的基本使用教程
Dec 22 Python
python使用shell脚本创建kafka连接器
Apr 29 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
用PHP书写安全的脚本代码
2012/02/05 PHP
php实现上传图片生成缩略图示例
2014/04/13 PHP
destoon实现资讯信息前面调用它所属分类的方法
2014/07/15 PHP
PHP mysqli事务操作常用方法分析
2017/07/22 PHP
javascript attachEvent绑定多个事件执行顺序问题
2010/10/20 Javascript
Javascript 静态页面实现随机显示广告的办法
2010/11/17 Javascript
js 程序执行与顺序实现详解
2013/05/13 Javascript
JS中把字符转成ASCII值的函数示例代码
2013/11/21 Javascript
node.js中的fs.rename方法使用说明
2014/12/16 Javascript
Bootstrap3使用typeahead插件实现自动补全功能
2016/07/07 Javascript
JavaScript随机打乱数组顺序之随机洗牌算法
2016/08/02 Javascript
AngularJS 指令的交互详解及实例代码
2016/09/14 Javascript
基于JavaScript实现焦点图轮播效果
2017/03/27 Javascript
JavaScript中this的用法及this在不同应用场景的作用解析
2017/04/13 Javascript
推荐三款不错的图片压缩上传插件(webuploader、localResizeIMG4、LUploader)
2017/04/21 Javascript
详解AngularJS1.x学习directive 中‘& ’‘=’ ‘@’符号的区别使用
2017/08/23 Javascript
Node.js中的child_process模块详解
2018/06/08 Javascript
Vue.js组件间通信方式总结【推荐】
2018/11/23 Javascript
配置node服务器并且链接微信公众号接口配置步骤详解
2019/06/21 Javascript
详解小程序BackgroundAudioManager踩坑之旅
2019/12/08 Javascript
如何使用 JavaScript 操作浏览器历史记录 API
2020/11/24 Javascript
[02:07]DOTA2超级联赛专访BBC:难忘网吧超神经历
2013/06/09 DOTA
Python中使用Boolean操作符做真值测试实例
2015/01/30 Python
python uuid模块使用实例
2015/04/08 Python
python中通过pip安装库文件时出现“EnvironmentError: [WinError 5] 拒绝访问”的问题及解决方案
2020/08/11 Python
利用纯html5绘制出来的一款非常漂亮的时钟
2015/01/04 HTML / CSS
爱尔兰电子产品购物网站:Komplett.ie
2018/04/04 全球购物
青年文明号口号
2014/06/17 职场文书
外贸会计专业自荐信
2014/06/22 职场文书
公司年底活动方案
2014/08/17 职场文书
庆祝教师节活动总结
2015/03/23 职场文书
2015年学校信息技术工作总结
2015/05/25 职场文书
jQuery class属性操作addClass()与removeClass()、hasClass()、toggleClass()
2021/03/31 jQuery
自从在 IDEA 中用了热部署神器 JRebel 之后,开发效率提升了 10(真棒)
2021/06/26 Java/Android
Vue监视数据的原理详解
2022/02/24 Vue.js
微软官方消息,在 2023 年 4 月 11 日之后微软将不再为 Office 2013 和 Skype for Business 2015 提供安全更新
2022/04/21 数码科技