pytorch加载语音类自定义数据集的方法教程


Posted in Python onNovember 10, 2020

前言

pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合

  • torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法
    •  __len()__ :返回数据集中数据的数量
    •   __getitem()__ :返回支持下标索引方式获取的一个数据
  • torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....

第一步

自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:

  • __len()__:读取数据,返回数据和标签
  • __getitem()__:返回数据集的长度
from torch.utils.data import Dataset


class AudioDataset(Dataset):
 def __init__(self, ...):
 """类的初始化"""
 pass

 def __getitem__(self, item):
 """每次怎么读数据,返回数据和标签"""
 return data, label

 def __len__(self):
 """返回整个数据集的长度"""
 return total

注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本

案例:

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

class AudioDataset(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
 self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
 max_audio_start = len(wb_wav) - self.dim
 audio_start = np.random.randint(0, max_audio_start)
 wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)

注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,

第二步

实例化 Dataset 对象

Dataset= AudioDataset("./p225", sr=16000)

如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)

for i, data in enumerate(train_set):
 wb_wav, filname = data
 print(i, wb_wav.shape, filname)

 if i == 3:
 break
 # 0 (8192,) ./p225\p225_001.wav
 # 1 (8192,) ./p225\p225_002.wav
 # 2 (8192,) ./p225\p225_003.wav
 # 3 (8192,) ./p225\p225_004.wav

第三步

如果想要通过batch读取数据,需要使用DataLoader进行包装

为何要使用DataLoader?

  • 深度学习的输入是mini_batch形式
  • 样本加载时候可能需要随机打乱顺序,shuffle操作
  • 样本加载需要采用多线程

pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

参数:

  • dataset:加载的数据集(Dataset对象)
  • batch_size:每个批次要加载多少个样本(默认值:1)
  • shuffle:每个epoch是否将数据打乱
  • sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
  • batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多线程
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

返回:数据加载器

案例:

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

我们来吃几个栗子消化一下:

栗子1

这个例子就是本文一直举例的,栗子1只是合并了一下而已

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

import fnmatch
import os
import librosa
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class Aduio_DataLoader(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
  for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
  self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 print(filename)
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
  max_audio_start = len(wb_wav) - self.dim
  audio_start = np.random.randint(0, max_audio_start)
  wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
  wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)


train_set = Aduio_DataLoader("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)


for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

注意事项:

  1. 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
  2. 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意

栗子2

相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。

我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。

具体实现代码:

第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:

第二步:通过Dataset从h5格式文件中读取数据

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py

def load_h5(h5_path):
 # load training data
 with h5py.File(h5_path, 'r') as hf:
 print('List of arrays in input file:', hf.keys())
 X = np.array(hf.get('data'), dtype=np.float32)
 Y = np.array(hf.get('label'), dtype=np.float32)
 return X, Y


class AudioDataset(Dataset):
 """数据加载器"""
 def __init__(self, data_folder):
 self.data_folder = data_folder
 self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1)

 def __getitem__(self, item):
 # 返回一个音频数据
 X = self.X[item]
 Y = self.Y[item]

 return X, Y

 def __len__(self):
 return len(self.X)


train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)


for (i, wav_data) in enumerate(train_loader):
 X, Y = wav_data
 print(i, X.shape)
 # 0 torch.Size([64, 8192, 1])
 # 1 torch.Size([64, 8192, 1])
 # ...

我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,

参考

总结

到此这篇关于pytorch加载语音类自定义数据集的文章就介绍到这了,更多相关pytorch加载语音类自定义数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python使用wmi模块获取windows下硬盘信息的方法
May 15 Python
Python使用Redis实现作业调度系统(超简单)
Mar 22 Python
TensorFlow平台下Python实现神经网络
Mar 10 Python
python+opencv 读取文件夹下的所有图像并批量保存ROI的方法
Jan 10 Python
详解Python3中setuptools、Pip安装教程
Jun 18 Python
Python中的正则表达式与JSON数据交换格式
Jul 03 Python
关于numpy.where()函数 返回值的解释
Dec 06 Python
Django生成PDF文档显示网页上以及PDF中文显示乱码的解决方法
Dec 17 Python
Python pathlib模块使用方法及实例解析
Oct 05 Python
python 检测nginx服务邮件报警的脚本
Dec 31 Python
Python数据清洗工具之Numpy的基本操作
Apr 22 Python
在 Python 中利用 Pool 进行多线程
Apr 24 Python
sublime3之内网安装python插件Anaconda的流程
Nov 10 #Python
python+excel接口自动化获取token并作为请求参数进行传参操作
Nov 10 #Python
python request 模块详细介绍
Nov 10 #Python
解决使用Pandas 读取超过65536行的Excel文件问题
Nov 10 #Python
python各种excel写入方式的速度对比
Nov 10 #Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
Nov 10 #Python
详解vscode实现远程linux服务器上Python开发
Nov 10 #Python
You might like
PHP之数组学习
2011/05/29 PHP
解析php中session的实现原理以及大网站应用应注意的问题
2013/06/17 PHP
php实现的Captcha验证码类实例
2014/09/22 PHP
Drupal简体中文语言包安装教程
2014/09/27 PHP
微信支付开发发货通知实例
2016/07/12 PHP
highchart数据源纵轴json内的值必须是int(详解)
2017/02/20 PHP
prototype 1.5相关知识及他人笔记
2006/12/16 Javascript
ymPrompt的doHandler方法来实现获取子窗口返回值的方法
2010/06/25 Javascript
一个字符串反转函数可实现字符串倒序
2014/09/15 Javascript
javascript中substring()、substr()、slice()的区别
2015/08/30 Javascript
BootStrap 模态框实现刷新网页并关闭功能
2017/01/04 Javascript
详解JavaScript中js对象与JSON格式字符串的相互转换
2017/02/14 Javascript
vue升级之路之vue-router的使用教程
2018/08/14 Javascript
微信小程序实现简单文字跑马灯
2020/05/26 Javascript
微信小程序实现电子签名功能
2020/07/29 Javascript
[51:34]Ti4主赛事胜者组 DK vs EG 2
2014/07/19 DOTA
python装饰器与递归算法详解
2016/02/18 Python
分享一下如何编写高效且优雅的 Python 代码
2017/09/07 Python
Python使用functools实现注解同步方法
2018/02/06 Python
python 用lambda函数替换for循环的方法
2018/06/09 Python
Python @property原理解析和用法实例
2020/02/11 Python
Mountain Warehouse德国官网:英国户外零售商
2019/08/11 全球购物
GUESS Factory加拿大:牛仔裤、服装及配饰
2019/09/20 全球购物
五好党支部事迹材料
2014/02/06 职场文书
班委竞选演讲稿
2014/04/28 职场文书
综艺节目策划方案
2014/06/13 职场文书
物流管理专业自荐信
2014/06/23 职场文书
教室布置标语
2014/06/26 职场文书
私人房屋买卖协议书
2014/10/04 职场文书
离婚协议书范本(通用篇)
2014/11/30 职场文书
努力学习保证书
2015/02/26 职场文书
2015年校医个人工作总结
2015/07/24 职场文书
遗嘱范文
2015/08/07 职场文书
关于五一放假的通知
2015/08/18 职场文书
22句经典语录:送给优柔寡断和胡思乱想的朋友们
2019/12/13 职场文书
windows10声卡驱动怎么安装?win10声卡驱动安装操作步骤教程
2022/08/05 数码科技