Pytorch自己加载单通道图片用作数据集训练的实例


Posted in Python onJanuary 18, 2020

pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100。在torchvision的dataset包里面,用的时候直接调用就行了。具体的调用格式可以去看文档(目前好像只有英文的)。网上也有很多源代码。

不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了。pytorch在torchvision.dataset包里面封装过一个函数ImageFolder()。这个函数功能很强大,只要你直接将数据集路径保存为例如“train/1/1.jpg ,rain/1/2.jpg …… ”就可以根据根目录“./train”将数据集装载了。

dataset.ImageFolder(root="datapath", transfroms.ToTensor())

但是后来我发现一个问题,就是这个函数加载出来的图像矩阵都是三通道的,并且没有什么参数调用可以让其变为单通道。如果我们要用到单通道数据集(灰度图)的话,比如自己加载Lenet-5模型的数据集,就只能自己写numpy数组再转为pytorch的Tensor()张量了。

接下来是我做的过程:

首先,还是要用到opencv,用灰度图打开一张图片,省事。

#读取图片 这里是灰度图 
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)

其次,pytorch有自己的numpy转Tensor函数,直接转就行了。

data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)

下一步利用torch.util和torchvision里面的dataLoader函数,就能直接得到和torchvision.dataset里面封装好的包相同的数据集样本了

dataset = dataf.TensorDataset(data_x,data_y)
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)

最后就是自己建网络设计参数训练了,这部分和文档以及github中的差不多,就不赘述了。

下面是整个程序的源代码,我利用的还是上次的车标识别的数据集,一共分四类,用的是2层卷积核两层全连接。

源代码:

# coding=utf-8
import os
import cv2
import numpy as np
import random
 
import torch
import torch.nn as nn
import torch.utils.data as dataf
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
 
#训练参数
cuda = False
train_epoch = 20
train_lr = 0.01
train_momentum = 0.5
batchsize = 5
 
 
#测试训练集路径
test_path = "/home/test/"
train_path = "/home/train/"
 
#路径数据
all_path =[]
 
def load_data(data_path):
 signal = os.listdir(data_path)
 for fsingal in signal: 
  filepath = data_path+fsingal
  filename = os.listdir(filepath)
  for fname in filename:
   ffpath = filepath+"/"+fname
   path = [fsingal,ffpath]
   all_path.append(path)
   
#设立数据集多大
 count = len(all_path)
 data_x = np.empty((count,1,28,28),dtype="float32")
 data_y = []
#打乱顺序
 random.shuffle(all_path)
 i=0;
 
#读取图片 这里是灰度图 最后结果是i*i*i*i
#分别表示:batch大小 , 通道数, 像素矩阵
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)
#  lener = len(all_path)
 data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)
 dataset = dataf.TensorDataset(data_x,data_y)
 
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)
  
 return loader
#  print data_y
 
 
 
train_load = load_data(train_path)
test_load = load_data(test_path)
 
class L5_NET(nn.Module):
 def __init__(self):
  super(L5_NET ,self).__init__();
  #第一层输入1,20个卷积核 每个5*5
  self.conv1 = nn.Conv2d(1 , 20 , kernel_size=5)
  #第二层输入20,30个卷积核 每个5*5
  self.conv2 = nn.Conv2d(20 , 30 , kernel_size=5)
  #drop函数
  self.conv2_drop = nn.Dropout2d()
  #全链接层1,展开30*4*4,连接层50个神经元
  self.fc1 = nn.Linear(30*4*4,50)
  #全链接层1,50-4 ,4为最后的输出分类
  self.fc2 = nn.Linear(50,4)
 
 #前向传播
 def forward(self,x):
  #池化层1 对于第一层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv1(x)  ,2 ) )
  #池化层2 对于第二层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv2_drop( self.conv2(x) ) , 2 ) )
  #平铺轴30*4*4个神经元
  x = x.view(-1 , 30*4*4)
  #全链接1
  x = F.relu( self.fc1(x) )
  #dropout链接
  x = F.dropout(x , training= self.training)
  #全链接w
  x = self.fc2(x)
  #softmax链接返回结果
  return F.log_softmax(x)
 
model = L5_NET()
if cuda :
 model.cuda()
  
 
optimizer = optim.SGD(model.parameters()  , lr =train_lr , momentum = train_momentum )
 
#预测函数
def train(epoch):
 model.train()
 for batch_idx, (data, target) in enumerate(train_load):
  if cuda:
   data, target = data.cuda(), target.cuda()
  data, target = Variable(data), Variable(target)
  #求导
  optimizer.zero_grad()
  #训练模型,输出结果
  output = model(data)
  #在数据集上预测loss
  loss = F.nll_loss(output, target)
  #反向传播调整参数pytorch直接可以用loss
  loss.backward()
  #SGD刷新进步
  optimizer.step()
  #实时输出
  if batch_idx % 10 == 0:
   print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    epoch, batch_idx * len(data), len(train_load.dataset),
    100. * batch_idx / len(train_load), loss.data[0]))
#    
   
#测试函数
def test(epoch):
 model.eval()
 test_loss = 0
 correct = 0
 for data, target in test_load:
  
  if cuda:
   data, target = data.cuda(), target.cuda()
   
  data, target = Variable(data, volatile=True), Variable(target)
  #在测试集上预测
  output = model(data)
  #计算在测试集上的loss
  test_loss += F.nll_loss(output, target).data[0]
  #获得预测的结果
  pred = output.data.max(1)[1] # get the index of the max log-probability
  #如果正确,correct+1
  correct += pred.eq(target.data).cpu().sum()
 
 #loss计算
 test_loss = test_loss
 test_loss /= len(test_load)
 #输出结果
 print('\nThe {} epoch result : Average loss: {:.6f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
  epoch,test_loss, correct, len(test_load.dataset),
  100. * correct / len(test_load.dataset)))
 
for epoch in range(1, train_epoch+ 1):
 train(epoch)
 test(epoch)

最后的训练结果和在keras下差不多,不过我训练的时候好像把训练集和测试集弄反了,数目好像测试集比训练集还多,有点尴尬,不过无伤大雅。结果图如下:

Pytorch自己加载单通道图片用作数据集训练的实例

以上这篇Pytorch自己加载单通道图片用作数据集训练的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中表达式x += y和x = x+y 的区别详解
Jun 20 Python
django模型层(model)进行建表、查询与删除的基础教程
Nov 21 Python
1分钟快速生成用于网页内容提取的xslt
Feb 23 Python
python分治法求二维数组局部峰值方法
Apr 03 Python
一看就懂得Python的math模块
Oct 21 Python
Python模拟浏览器上传文件脚本的方法(Multipart/form-data格式)
Oct 22 Python
详解python3安装pillow后报错没有pillow模块以及没有PIL模块问题解决
Apr 17 Python
pyinstaller打包多个py文件和去除cmd黑框的方法
Jun 21 Python
python3 BeautifulSoup模块使用字典的方法抓取a标签内的数据示例
Nov 28 Python
Python中顺序表原理与实现方法详解
Dec 03 Python
Python 读写 Matlab Mat 格式数据的操作
May 19 Python
Python移位密码、仿射变换解密实例代码
Jun 27 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 #Python
Python中实现输入超时及如何通过变量获取变量名
Jan 18 #Python
Pytorch 计算误判率,计算准确率,计算召回率的例子
Jan 18 #Python
python:目标检测模型预测准确度计算方式(基于IoU)
Jan 18 #Python
You might like
php生成zip文件类实例
2015/04/07 PHP
在Laravel中使用DataTables插件的方法
2018/05/29 PHP
PHP的mysqli_rollback()函数讲解
2019/01/23 PHP
php使用event扩展的io复用测试的示例
2020/10/20 PHP
PJ Blog修改-禁止复制的代码和方法
2006/10/25 Javascript
网页和浏览器兼容性问题汇总(draft1)
2009/06/01 Javascript
jQuery DOM操作小结与实例
2010/01/07 Javascript
AJAX异步从优酷专辑中采集所有视频及信息(JavaScript代码)
2010/11/20 Javascript
jquery(live)中File input的change方法只起一次作用的解决办法
2011/10/21 Javascript
不同的jQuery API来处理不同的浏览器事件
2012/12/09 Javascript
JavaScript中的常见问题解决方法(乱码,IE缓存,代理)
2013/11/28 Javascript
JS弹出层的显示与隐藏示例代码
2013/12/27 Javascript
bootstrap布局中input输入框右侧图标点击功能
2016/05/16 Javascript
Vue 2.0在IE11中打开项目页面空白的问题解决
2017/07/16 Javascript
React中使用collections时key的重要性详解
2017/08/07 Javascript
Vue.js实现双向数据绑定方法(表单自动赋值、表单自动取值)
2018/08/27 Javascript
Vuejs开发环境搭建及热更新【推荐】
2018/09/07 Javascript
Python模块学习 filecmp 文件比较
2012/08/27 Python
Pycharm学习教程(1) 定制外观
2017/05/02 Python
python pygame实现五子棋小游戏
2020/10/26 Python
python实现屏保程序(适用于背单词)
2019/07/30 Python
Python 支持向量机分类器的实现
2020/01/15 Python
利用Python裁切tiff图像且读取tiff,shp文件的实例
2020/03/10 Python
Python数据正态性检验实现过程
2020/04/18 Python
使用Python实现音频双通道分离
2020/12/25 Python
英国复古服装和球衣购买网站:3Retro Football
2018/07/09 全球购物
Interhome丹麦:在线预订度假屋和公寓
2019/07/18 全球购物
智能室内花园:Click & Grow
2021/01/29 全球购物
编写用C语言实现的求n阶阶乘问题的递归算法
2014/10/21 面试题
枚举和一组预处理的#define有什么不同
2016/09/21 面试题
生物制药毕业生自荐信
2013/10/16 职场文书
教师岗位聘任书范文
2014/03/29 职场文书
公司担保书格式范文
2014/05/12 职场文书
中国梦读书活动总结
2014/07/10 职场文书
2015国际残疾人日活动总结
2015/03/24 职场文书
2015年文明创建工作总结
2015/04/30 职场文书