Pytorch自己加载单通道图片用作数据集训练的实例


Posted in Python onJanuary 18, 2020

pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100。在torchvision的dataset包里面,用的时候直接调用就行了。具体的调用格式可以去看文档(目前好像只有英文的)。网上也有很多源代码。

不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了。pytorch在torchvision.dataset包里面封装过一个函数ImageFolder()。这个函数功能很强大,只要你直接将数据集路径保存为例如“train/1/1.jpg ,rain/1/2.jpg …… ”就可以根据根目录“./train”将数据集装载了。

dataset.ImageFolder(root="datapath", transfroms.ToTensor())

但是后来我发现一个问题,就是这个函数加载出来的图像矩阵都是三通道的,并且没有什么参数调用可以让其变为单通道。如果我们要用到单通道数据集(灰度图)的话,比如自己加载Lenet-5模型的数据集,就只能自己写numpy数组再转为pytorch的Tensor()张量了。

接下来是我做的过程:

首先,还是要用到opencv,用灰度图打开一张图片,省事。

#读取图片 这里是灰度图 
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)

其次,pytorch有自己的numpy转Tensor函数,直接转就行了。

data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)

下一步利用torch.util和torchvision里面的dataLoader函数,就能直接得到和torchvision.dataset里面封装好的包相同的数据集样本了

dataset = dataf.TensorDataset(data_x,data_y)
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)

最后就是自己建网络设计参数训练了,这部分和文档以及github中的差不多,就不赘述了。

下面是整个程序的源代码,我利用的还是上次的车标识别的数据集,一共分四类,用的是2层卷积核两层全连接。

源代码:

# coding=utf-8
import os
import cv2
import numpy as np
import random
 
import torch
import torch.nn as nn
import torch.utils.data as dataf
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
 
#训练参数
cuda = False
train_epoch = 20
train_lr = 0.01
train_momentum = 0.5
batchsize = 5
 
 
#测试训练集路径
test_path = "/home/test/"
train_path = "/home/train/"
 
#路径数据
all_path =[]
 
def load_data(data_path):
 signal = os.listdir(data_path)
 for fsingal in signal: 
  filepath = data_path+fsingal
  filename = os.listdir(filepath)
  for fname in filename:
   ffpath = filepath+"/"+fname
   path = [fsingal,ffpath]
   all_path.append(path)
   
#设立数据集多大
 count = len(all_path)
 data_x = np.empty((count,1,28,28),dtype="float32")
 data_y = []
#打乱顺序
 random.shuffle(all_path)
 i=0;
 
#读取图片 这里是灰度图 最后结果是i*i*i*i
#分别表示:batch大小 , 通道数, 像素矩阵
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)
#  lener = len(all_path)
 data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)
 dataset = dataf.TensorDataset(data_x,data_y)
 
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)
  
 return loader
#  print data_y
 
 
 
train_load = load_data(train_path)
test_load = load_data(test_path)
 
class L5_NET(nn.Module):
 def __init__(self):
  super(L5_NET ,self).__init__();
  #第一层输入1,20个卷积核 每个5*5
  self.conv1 = nn.Conv2d(1 , 20 , kernel_size=5)
  #第二层输入20,30个卷积核 每个5*5
  self.conv2 = nn.Conv2d(20 , 30 , kernel_size=5)
  #drop函数
  self.conv2_drop = nn.Dropout2d()
  #全链接层1,展开30*4*4,连接层50个神经元
  self.fc1 = nn.Linear(30*4*4,50)
  #全链接层1,50-4 ,4为最后的输出分类
  self.fc2 = nn.Linear(50,4)
 
 #前向传播
 def forward(self,x):
  #池化层1 对于第一层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv1(x)  ,2 ) )
  #池化层2 对于第二层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv2_drop( self.conv2(x) ) , 2 ) )
  #平铺轴30*4*4个神经元
  x = x.view(-1 , 30*4*4)
  #全链接1
  x = F.relu( self.fc1(x) )
  #dropout链接
  x = F.dropout(x , training= self.training)
  #全链接w
  x = self.fc2(x)
  #softmax链接返回结果
  return F.log_softmax(x)
 
model = L5_NET()
if cuda :
 model.cuda()
  
 
optimizer = optim.SGD(model.parameters()  , lr =train_lr , momentum = train_momentum )
 
#预测函数
def train(epoch):
 model.train()
 for batch_idx, (data, target) in enumerate(train_load):
  if cuda:
   data, target = data.cuda(), target.cuda()
  data, target = Variable(data), Variable(target)
  #求导
  optimizer.zero_grad()
  #训练模型,输出结果
  output = model(data)
  #在数据集上预测loss
  loss = F.nll_loss(output, target)
  #反向传播调整参数pytorch直接可以用loss
  loss.backward()
  #SGD刷新进步
  optimizer.step()
  #实时输出
  if batch_idx % 10 == 0:
   print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    epoch, batch_idx * len(data), len(train_load.dataset),
    100. * batch_idx / len(train_load), loss.data[0]))
#    
   
#测试函数
def test(epoch):
 model.eval()
 test_loss = 0
 correct = 0
 for data, target in test_load:
  
  if cuda:
   data, target = data.cuda(), target.cuda()
   
  data, target = Variable(data, volatile=True), Variable(target)
  #在测试集上预测
  output = model(data)
  #计算在测试集上的loss
  test_loss += F.nll_loss(output, target).data[0]
  #获得预测的结果
  pred = output.data.max(1)[1] # get the index of the max log-probability
  #如果正确,correct+1
  correct += pred.eq(target.data).cpu().sum()
 
 #loss计算
 test_loss = test_loss
 test_loss /= len(test_load)
 #输出结果
 print('\nThe {} epoch result : Average loss: {:.6f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
  epoch,test_loss, correct, len(test_load.dataset),
  100. * correct / len(test_load.dataset)))
 
for epoch in range(1, train_epoch+ 1):
 train(epoch)
 test(epoch)

最后的训练结果和在keras下差不多,不过我训练的时候好像把训练集和测试集弄反了,数目好像测试集比训练集还多,有点尴尬,不过无伤大雅。结果图如下:

Pytorch自己加载单通道图片用作数据集训练的实例

以上这篇Pytorch自己加载单通道图片用作数据集训练的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动化测试之setUp与tearDown实例
Sep 28 Python
详解Python中find()方法的使用
May 18 Python
python3.5使用tkinter制作记事本
Jun 20 Python
简单谈谈python中的Queue与多进程
Aug 25 Python
Python爬取附近餐馆信息代码示例
Dec 09 Python
python实现二维数组的对角线遍历
Mar 02 Python
Python动态语言与鸭子类型详解
Jul 01 Python
Kears+Opencv实现简单人脸识别
Aug 28 Python
Python获取时间戳代码实例
Sep 24 Python
Python3+selenium配置常见报错解决方案
Aug 28 Python
Python 数据分析之逐块读取文本的实现
Dec 14 Python
Django中的JWT身份验证的实现
May 07 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 #Python
Python中实现输入超时及如何通过变量获取变量名
Jan 18 #Python
Pytorch 计算误判率,计算准确率,计算召回率的例子
Jan 18 #Python
python:目标检测模型预测准确度计算方式(基于IoU)
Jan 18 #Python
You might like
Laravel5权限管理方法详解
2016/07/26 PHP
php常用数组array函数实例总结【赋值,拆分,合并,计算,添加,删除,查询,判断,排序】
2016/12/07 PHP
ThinkPHP中调用PHPExcel的实现代码
2017/04/08 PHP
YII框架http缓存操作示例
2019/04/29 PHP
在线编辑器的实现原理(兼容IE和FireFox)
2007/03/09 Javascript
jquery 屏蔽一个区域内的所有元素,禁止输入
2009/10/22 Javascript
最简单的js图片切换效果实现代码
2011/09/24 Javascript
使用ajaxfileupload.js实现ajax上传文件php版
2014/06/26 Javascript
js实现文字在按钮上滚动的方法
2015/08/20 Javascript
jQuery点击按钮弹出遮罩层且内容居中特效
2015/12/14 Javascript
基于JavaScript代码实现兼容各浏览器的设为首页和加入收藏
2016/01/07 Javascript
JavaScript将DOM事件处理程序封装为event.js 出现的低级错误问题
2016/08/03 Javascript
微信小程序使用第三方库Underscore.js步骤详解
2016/09/27 Javascript
jQuery实现菜单栏导航效果
2017/08/15 jQuery
JS实现颜色的10进制转化成rgba格式的方法
2017/09/04 Javascript
vue+elementUI 复杂表单的验证、数据提交方案问题
2019/06/24 Javascript
layui动态表头的实现代码
2019/08/22 Javascript
uni-app如何页面传参数的几种方法总结
2020/04/28 Javascript
python连接oracle数据库实例
2014/10/17 Python
Python导出DBF文件到Excel的方法
2015/07/25 Python
python3使用smtplib实现发送邮件功能
2018/05/22 Python
使用python生成杨辉三角形的示例代码
2018/08/29 Python
解决python中使用PYQT时中文乱码问题
2019/06/17 Python
Gauss-Seidel迭代算法的Python实现详解
2019/06/29 Python
Python列表元素常见操作简单示例
2019/10/25 Python
使用python 对验证码图片进行降噪处理
2019/12/18 Python
Django 实现将图片转为Base64,然后使用json传输
2020/03/27 Python
python 实现ping测试延迟的两种方法
2020/12/10 Python
简单掌握CSS3将文字描边及填充文字颜色的方法
2016/03/07 HTML / CSS
HTML5 WebGL 实现民航客机飞行监控系统
2019/07/25 HTML / CSS
html5 canvas简单封装一个echarts实现不了的饼图
2018/06/12 HTML / CSS
电影T恤、80年代T恤和80年代服装:TV Store Online
2020/01/05 全球购物
财产分割协议书范本
2014/11/03 职场文书
单位接收函格式
2015/01/30 职场文书
2015年信息中心工作总结
2015/05/25 职场文书
小学运动会加油词
2015/07/18 职场文书