Pytorch自己加载单通道图片用作数据集训练的实例


Posted in Python onJanuary 18, 2020

pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100。在torchvision的dataset包里面,用的时候直接调用就行了。具体的调用格式可以去看文档(目前好像只有英文的)。网上也有很多源代码。

不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了。pytorch在torchvision.dataset包里面封装过一个函数ImageFolder()。这个函数功能很强大,只要你直接将数据集路径保存为例如“train/1/1.jpg ,rain/1/2.jpg …… ”就可以根据根目录“./train”将数据集装载了。

dataset.ImageFolder(root="datapath", transfroms.ToTensor())

但是后来我发现一个问题,就是这个函数加载出来的图像矩阵都是三通道的,并且没有什么参数调用可以让其变为单通道。如果我们要用到单通道数据集(灰度图)的话,比如自己加载Lenet-5模型的数据集,就只能自己写numpy数组再转为pytorch的Tensor()张量了。

接下来是我做的过程:

首先,还是要用到opencv,用灰度图打开一张图片,省事。

#读取图片 这里是灰度图 
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)

其次,pytorch有自己的numpy转Tensor函数,直接转就行了。

data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)

下一步利用torch.util和torchvision里面的dataLoader函数,就能直接得到和torchvision.dataset里面封装好的包相同的数据集样本了

dataset = dataf.TensorDataset(data_x,data_y)
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)

最后就是自己建网络设计参数训练了,这部分和文档以及github中的差不多,就不赘述了。

下面是整个程序的源代码,我利用的还是上次的车标识别的数据集,一共分四类,用的是2层卷积核两层全连接。

源代码:

# coding=utf-8
import os
import cv2
import numpy as np
import random
 
import torch
import torch.nn as nn
import torch.utils.data as dataf
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
 
#训练参数
cuda = False
train_epoch = 20
train_lr = 0.01
train_momentum = 0.5
batchsize = 5
 
 
#测试训练集路径
test_path = "/home/test/"
train_path = "/home/train/"
 
#路径数据
all_path =[]
 
def load_data(data_path):
 signal = os.listdir(data_path)
 for fsingal in signal: 
  filepath = data_path+fsingal
  filename = os.listdir(filepath)
  for fname in filename:
   ffpath = filepath+"/"+fname
   path = [fsingal,ffpath]
   all_path.append(path)
   
#设立数据集多大
 count = len(all_path)
 data_x = np.empty((count,1,28,28),dtype="float32")
 data_y = []
#打乱顺序
 random.shuffle(all_path)
 i=0;
 
#读取图片 这里是灰度图 最后结果是i*i*i*i
#分别表示:batch大小 , 通道数, 像素矩阵
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)
#  lener = len(all_path)
 data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)
 dataset = dataf.TensorDataset(data_x,data_y)
 
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)
  
 return loader
#  print data_y
 
 
 
train_load = load_data(train_path)
test_load = load_data(test_path)
 
class L5_NET(nn.Module):
 def __init__(self):
  super(L5_NET ,self).__init__();
  #第一层输入1,20个卷积核 每个5*5
  self.conv1 = nn.Conv2d(1 , 20 , kernel_size=5)
  #第二层输入20,30个卷积核 每个5*5
  self.conv2 = nn.Conv2d(20 , 30 , kernel_size=5)
  #drop函数
  self.conv2_drop = nn.Dropout2d()
  #全链接层1,展开30*4*4,连接层50个神经元
  self.fc1 = nn.Linear(30*4*4,50)
  #全链接层1,50-4 ,4为最后的输出分类
  self.fc2 = nn.Linear(50,4)
 
 #前向传播
 def forward(self,x):
  #池化层1 对于第一层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv1(x)  ,2 ) )
  #池化层2 对于第二层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv2_drop( self.conv2(x) ) , 2 ) )
  #平铺轴30*4*4个神经元
  x = x.view(-1 , 30*4*4)
  #全链接1
  x = F.relu( self.fc1(x) )
  #dropout链接
  x = F.dropout(x , training= self.training)
  #全链接w
  x = self.fc2(x)
  #softmax链接返回结果
  return F.log_softmax(x)
 
model = L5_NET()
if cuda :
 model.cuda()
  
 
optimizer = optim.SGD(model.parameters()  , lr =train_lr , momentum = train_momentum )
 
#预测函数
def train(epoch):
 model.train()
 for batch_idx, (data, target) in enumerate(train_load):
  if cuda:
   data, target = data.cuda(), target.cuda()
  data, target = Variable(data), Variable(target)
  #求导
  optimizer.zero_grad()
  #训练模型,输出结果
  output = model(data)
  #在数据集上预测loss
  loss = F.nll_loss(output, target)
  #反向传播调整参数pytorch直接可以用loss
  loss.backward()
  #SGD刷新进步
  optimizer.step()
  #实时输出
  if batch_idx % 10 == 0:
   print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    epoch, batch_idx * len(data), len(train_load.dataset),
    100. * batch_idx / len(train_load), loss.data[0]))
#    
   
#测试函数
def test(epoch):
 model.eval()
 test_loss = 0
 correct = 0
 for data, target in test_load:
  
  if cuda:
   data, target = data.cuda(), target.cuda()
   
  data, target = Variable(data, volatile=True), Variable(target)
  #在测试集上预测
  output = model(data)
  #计算在测试集上的loss
  test_loss += F.nll_loss(output, target).data[0]
  #获得预测的结果
  pred = output.data.max(1)[1] # get the index of the max log-probability
  #如果正确,correct+1
  correct += pred.eq(target.data).cpu().sum()
 
 #loss计算
 test_loss = test_loss
 test_loss /= len(test_load)
 #输出结果
 print('\nThe {} epoch result : Average loss: {:.6f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
  epoch,test_loss, correct, len(test_load.dataset),
  100. * correct / len(test_load.dataset)))
 
for epoch in range(1, train_epoch+ 1):
 train(epoch)
 test(epoch)

最后的训练结果和在keras下差不多,不过我训练的时候好像把训练集和测试集弄反了,数目好像测试集比训练集还多,有点尴尬,不过无伤大雅。结果图如下:

Pytorch自己加载单通道图片用作数据集训练的实例

以上这篇Pytorch自己加载单通道图片用作数据集训练的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过colorama模块在控制台输出彩色文字的方法
Mar 19 Python
python获取一组汉字拼音首字母的方法
Jul 01 Python
Python中常用信号signal类型实例
Jan 25 Python
Jupyter中直接显示Matplotlib的图形方法
May 24 Python
Python PyAutoGUI模块控制鼠标和键盘实现自动化任务详解
Sep 04 Python
pandas去除重复列的实现方法
Jan 29 Python
Python 从subprocess运行的子进程中实时获取输出的例子
Aug 14 Python
pytorch常见的Tensor类型详解
Jan 15 Python
python实现交并比IOU教程
Apr 16 Python
python 常用日期处理-- datetime 模块的使用
Sep 02 Python
Python通过类的组合模拟街道红绿灯
Sep 16 Python
python关于倒排列的知识点总结
Oct 13 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 #Python
Python中实现输入超时及如何通过变量获取变量名
Jan 18 #Python
Pytorch 计算误判率,计算准确率,计算召回率的例子
Jan 18 #Python
python:目标检测模型预测准确度计算方式(基于IoU)
Jan 18 #Python
You might like
解析link_mysql的php版
2013/06/30 PHP
PHP字符串中特殊符号的过滤方法介绍
2014/02/18 PHP
JavaScript通过元素的ID和name设置样式
2014/07/08 Javascript
jQuery CSS()方法改变现有的CSS样式表
2014/09/09 Javascript
angularJS中router的使用指南
2015/02/09 Javascript
JS动态改变浏览器标题的方法
2016/04/06 Javascript
jQuery获取当前点击的对象元素(实现代码)
2016/05/19 Javascript
JavaScript导航脚本判断当前导航
2016/07/12 Javascript
在DWR中实现直接获取一个JAVA类的返回值的两种方法
2016/12/25 Javascript
ES6新数据结构Map功能与用法示例
2017/03/31 Javascript
深入理解基于vue-cli的vuex配置
2017/07/24 Javascript
详解create-react-app 自定义 eslint 配置
2018/06/07 Javascript
Vue 3.0双向绑定原理的实现方法
2019/10/23 Javascript
微信小程序实现滑动翻页效果(完整代码)
2019/12/06 Javascript
全局安装 Vue cli3 和 继续使用 Vue-cli2.x操作
2020/09/08 Javascript
[01:01:51]EG vs VG Supermajor小组赛B组 BO3 第二场 6.2
2018/06/03 DOTA
python中zip()方法应用实例分析
2016/04/16 Python
Python设置默认编码为utf8的方法
2016/07/01 Python
Python存取XML的常见方法实例分析
2017/03/21 Python
python3读取excel文件只提取某些行某些列的值方法
2018/07/10 Python
django如何连接已存在数据的数据库
2018/08/14 Python
python  Django中的apps.py的目的是什么
2018/10/15 Python
python中如何使用insert函数
2020/01/09 Python
python打包生成so文件的实现
2020/10/30 Python
CSS3实现可爱的小黄人动画
2016/07/11 HTML / CSS
HTML实现代码雨源码及效果示例
2020/02/25 HTML / CSS
美国鞋类购物网站:Shiekh Shoes
2016/08/21 全球购物
java程序员面试交流
2012/11/29 面试题
大学生职业生涯规划书模版
2013/12/30 职场文书
实习会计求职自荐信范文
2014/03/10 职场文书
机关干部四风问题自查报告及整改措施
2014/10/26 职场文书
国庆庆典邀请函
2015/02/02 职场文书
签订劳动合同通知书
2015/04/16 职场文书
公司劳动纪律管理制度
2015/08/04 职场文书
《穷人》教学反思
2016/02/19 职场文书
JS前端宏任务微任务及Event Loop使用详解
2022/07/23 Javascript