Pytorch自己加载单通道图片用作数据集训练的实例


Posted in Python onJanuary 18, 2020

pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100。在torchvision的dataset包里面,用的时候直接调用就行了。具体的调用格式可以去看文档(目前好像只有英文的)。网上也有很多源代码。

不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了。pytorch在torchvision.dataset包里面封装过一个函数ImageFolder()。这个函数功能很强大,只要你直接将数据集路径保存为例如“train/1/1.jpg ,rain/1/2.jpg …… ”就可以根据根目录“./train”将数据集装载了。

dataset.ImageFolder(root="datapath", transfroms.ToTensor())

但是后来我发现一个问题,就是这个函数加载出来的图像矩阵都是三通道的,并且没有什么参数调用可以让其变为单通道。如果我们要用到单通道数据集(灰度图)的话,比如自己加载Lenet-5模型的数据集,就只能自己写numpy数组再转为pytorch的Tensor()张量了。

接下来是我做的过程:

首先,还是要用到opencv,用灰度图打开一张图片,省事。

#读取图片 这里是灰度图 
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)

其次,pytorch有自己的numpy转Tensor函数,直接转就行了。

data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)

下一步利用torch.util和torchvision里面的dataLoader函数,就能直接得到和torchvision.dataset里面封装好的包相同的数据集样本了

dataset = dataf.TensorDataset(data_x,data_y)
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)

最后就是自己建网络设计参数训练了,这部分和文档以及github中的差不多,就不赘述了。

下面是整个程序的源代码,我利用的还是上次的车标识别的数据集,一共分四类,用的是2层卷积核两层全连接。

源代码:

# coding=utf-8
import os
import cv2
import numpy as np
import random
 
import torch
import torch.nn as nn
import torch.utils.data as dataf
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
 
#训练参数
cuda = False
train_epoch = 20
train_lr = 0.01
train_momentum = 0.5
batchsize = 5
 
 
#测试训练集路径
test_path = "/home/test/"
train_path = "/home/train/"
 
#路径数据
all_path =[]
 
def load_data(data_path):
 signal = os.listdir(data_path)
 for fsingal in signal: 
  filepath = data_path+fsingal
  filename = os.listdir(filepath)
  for fname in filename:
   ffpath = filepath+"/"+fname
   path = [fsingal,ffpath]
   all_path.append(path)
   
#设立数据集多大
 count = len(all_path)
 data_x = np.empty((count,1,28,28),dtype="float32")
 data_y = []
#打乱顺序
 random.shuffle(all_path)
 i=0;
 
#读取图片 这里是灰度图 最后结果是i*i*i*i
#分别表示:batch大小 , 通道数, 像素矩阵
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)
#  lener = len(all_path)
 data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)
 dataset = dataf.TensorDataset(data_x,data_y)
 
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)
  
 return loader
#  print data_y
 
 
 
train_load = load_data(train_path)
test_load = load_data(test_path)
 
class L5_NET(nn.Module):
 def __init__(self):
  super(L5_NET ,self).__init__();
  #第一层输入1,20个卷积核 每个5*5
  self.conv1 = nn.Conv2d(1 , 20 , kernel_size=5)
  #第二层输入20,30个卷积核 每个5*5
  self.conv2 = nn.Conv2d(20 , 30 , kernel_size=5)
  #drop函数
  self.conv2_drop = nn.Dropout2d()
  #全链接层1,展开30*4*4,连接层50个神经元
  self.fc1 = nn.Linear(30*4*4,50)
  #全链接层1,50-4 ,4为最后的输出分类
  self.fc2 = nn.Linear(50,4)
 
 #前向传播
 def forward(self,x):
  #池化层1 对于第一层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv1(x)  ,2 ) )
  #池化层2 对于第二层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv2_drop( self.conv2(x) ) , 2 ) )
  #平铺轴30*4*4个神经元
  x = x.view(-1 , 30*4*4)
  #全链接1
  x = F.relu( self.fc1(x) )
  #dropout链接
  x = F.dropout(x , training= self.training)
  #全链接w
  x = self.fc2(x)
  #softmax链接返回结果
  return F.log_softmax(x)
 
model = L5_NET()
if cuda :
 model.cuda()
  
 
optimizer = optim.SGD(model.parameters()  , lr =train_lr , momentum = train_momentum )
 
#预测函数
def train(epoch):
 model.train()
 for batch_idx, (data, target) in enumerate(train_load):
  if cuda:
   data, target = data.cuda(), target.cuda()
  data, target = Variable(data), Variable(target)
  #求导
  optimizer.zero_grad()
  #训练模型,输出结果
  output = model(data)
  #在数据集上预测loss
  loss = F.nll_loss(output, target)
  #反向传播调整参数pytorch直接可以用loss
  loss.backward()
  #SGD刷新进步
  optimizer.step()
  #实时输出
  if batch_idx % 10 == 0:
   print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    epoch, batch_idx * len(data), len(train_load.dataset),
    100. * batch_idx / len(train_load), loss.data[0]))
#    
   
#测试函数
def test(epoch):
 model.eval()
 test_loss = 0
 correct = 0
 for data, target in test_load:
  
  if cuda:
   data, target = data.cuda(), target.cuda()
   
  data, target = Variable(data, volatile=True), Variable(target)
  #在测试集上预测
  output = model(data)
  #计算在测试集上的loss
  test_loss += F.nll_loss(output, target).data[0]
  #获得预测的结果
  pred = output.data.max(1)[1] # get the index of the max log-probability
  #如果正确,correct+1
  correct += pred.eq(target.data).cpu().sum()
 
 #loss计算
 test_loss = test_loss
 test_loss /= len(test_load)
 #输出结果
 print('\nThe {} epoch result : Average loss: {:.6f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
  epoch,test_loss, correct, len(test_load.dataset),
  100. * correct / len(test_load.dataset)))
 
for epoch in range(1, train_epoch+ 1):
 train(epoch)
 test(epoch)

最后的训练结果和在keras下差不多,不过我训练的时候好像把训练集和测试集弄反了,数目好像测试集比训练集还多,有点尴尬,不过无伤大雅。结果图如下:

Pytorch自己加载单通道图片用作数据集训练的实例

以上这篇Pytorch自己加载单通道图片用作数据集训练的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的pycurl包用法简介
Nov 13 Python
Python列表切片用法示例
Apr 19 Python
python 检查文件mime类型的方法
Dec 08 Python
Python3爬虫学习之爬虫利器Beautiful Soup用法分析
Dec 12 Python
利用python实现在微信群刷屏的方法
Feb 21 Python
Python计算时间间隔(精确到微妙)的代码实例
Feb 26 Python
Python实现定期检查源目录与备份目录的差异并进行备份功能示例
Feb 27 Python
python 将有序数组转换为二叉树的方法
Mar 26 Python
使用python爬取抖音视频列表信息
Jul 15 Python
Python学习笔记之For循环用法详解
Aug 14 Python
Pytorch 中retain_graph的用法详解
Jan 07 Python
Matplotlib绘制条形图的方法你知道吗
Mar 21 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 #Python
Python中实现输入超时及如何通过变量获取变量名
Jan 18 #Python
Pytorch 计算误判率,计算准确率,计算召回率的例子
Jan 18 #Python
python:目标检测模型预测准确度计算方式(基于IoU)
Jan 18 #Python
You might like
WIN98下Apache1.3.14+PHP4.0.4的安装
2006/10/09 PHP
php设计模式 Delegation(委托模式)
2011/06/26 PHP
PHP使用GIFEncoder类生成gif动态滚动字幕
2014/07/01 PHP
PHP curl伪造IP地址和header信息代码实例
2015/04/27 PHP
Javascript 面向对象编程(coolshell)
2012/03/18 Javascript
jQuery基本选择器选择元素使用介绍
2013/04/18 Javascript
javascript解决innerText浏览器兼容问题思路代码
2013/05/17 Javascript
jQuery实现单击弹出Div层窗口效果(可关闭可拖动)
2015/09/19 Javascript
Bootstrap Validator 表单验证
2016/07/25 Javascript
Three.js快速入门教程
2016/09/09 Javascript
JS禁止查看网页源代码的实现方法
2016/10/12 Javascript
微信小程序 后台https域名绑定和免费的https证书申请详解
2016/11/10 Javascript
JS基于onclick事件实现单个按钮的编辑与保存功能示例
2017/02/13 Javascript
Bootstrap输入框组件使用详解
2017/06/09 Javascript
JavaScript模拟实现封装的三种方式及写法区别
2017/10/27 Javascript
在vue项目中安装使用Mint-UI的方法
2017/12/27 Javascript
浅谈Vue组件及组件的注册方法
2018/08/24 Javascript
微信小程序云开发 生成带参小程序码流程
2019/05/18 Javascript
tracking.js实现前端人脸识别功能
2020/04/16 Javascript
vue v-for 点击当前行,获取当前行数据及event当前事件对象的操作
2020/09/10 Javascript
原生js实现照片墙效果
2020/10/13 Javascript
微信小程序实现简单的select下拉框
2020/11/23 Javascript
[04:49]2014DOTA2国际邀请赛 Newbee顺利挺进总决赛 ImbaTV独家专访
2014/07/19 DOTA
Python爬虫信息输入及页面的切换方法
2018/05/11 Python
Python中将两个或多个list合成一个list的方法小结
2019/05/12 Python
Python中openpyxl实现vlookup函数的实例
2020/10/28 Python
10分钟入门CSS3 Animation
2018/12/25 HTML / CSS
美国第一香水网站:Perfume.com
2017/01/23 全球购物
如何在Cookie里面保存Unicode和国际化字符
2013/05/25 面试题
.NET面试10题
2014/02/24 面试题
租房协议书范文
2014/08/20 职场文书
护士求职自荐信范文
2015/03/04 职场文书
品牌形象定位,全面分析
2019/07/23 职场文书
python 爬取哔哩哔哩up主信息和投稿视频
2021/06/07 Python
vue实现列表垂直无缝滚动
2022/04/08 Vue.js
vue实现登陆页面开发实践
2022/05/30 Vue.js