pytorch加载语音类自定义数据集的方法教程


Posted in Python onNovember 10, 2020

前言

pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合

  • torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法
    •  __len()__ :返回数据集中数据的数量
    •   __getitem()__ :返回支持下标索引方式获取的一个数据
  • torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....

第一步

自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:

  • __len()__:读取数据,返回数据和标签
  • __getitem()__:返回数据集的长度
from torch.utils.data import Dataset


class AudioDataset(Dataset):
 def __init__(self, ...):
 """类的初始化"""
 pass

 def __getitem__(self, item):
 """每次怎么读数据,返回数据和标签"""
 return data, label

 def __len__(self):
 """返回整个数据集的长度"""
 return total

注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本

案例:

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

class AudioDataset(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
 self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
 max_audio_start = len(wb_wav) - self.dim
 audio_start = np.random.randint(0, max_audio_start)
 wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)

注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,

第二步

实例化 Dataset 对象

Dataset= AudioDataset("./p225", sr=16000)

如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)

for i, data in enumerate(train_set):
 wb_wav, filname = data
 print(i, wb_wav.shape, filname)

 if i == 3:
 break
 # 0 (8192,) ./p225\p225_001.wav
 # 1 (8192,) ./p225\p225_002.wav
 # 2 (8192,) ./p225\p225_003.wav
 # 3 (8192,) ./p225\p225_004.wav

第三步

如果想要通过batch读取数据,需要使用DataLoader进行包装

为何要使用DataLoader?

  • 深度学习的输入是mini_batch形式
  • 样本加载时候可能需要随机打乱顺序,shuffle操作
  • 样本加载需要采用多线程

pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

参数:

  • dataset:加载的数据集(Dataset对象)
  • batch_size:每个批次要加载多少个样本(默认值:1)
  • shuffle:每个epoch是否将数据打乱
  • sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
  • batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多线程
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

返回:数据加载器

案例:

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

我们来吃几个栗子消化一下:

栗子1

这个例子就是本文一直举例的,栗子1只是合并了一下而已

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

import fnmatch
import os
import librosa
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class Aduio_DataLoader(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
  for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
  self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 print(filename)
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
  max_audio_start = len(wb_wav) - self.dim
  audio_start = np.random.randint(0, max_audio_start)
  wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
  wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)


train_set = Aduio_DataLoader("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)


for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

注意事项:

  1. 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
  2. 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意

栗子2

相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。

我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。

具体实现代码:

第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:

第二步:通过Dataset从h5格式文件中读取数据

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py

def load_h5(h5_path):
 # load training data
 with h5py.File(h5_path, 'r') as hf:
 print('List of arrays in input file:', hf.keys())
 X = np.array(hf.get('data'), dtype=np.float32)
 Y = np.array(hf.get('label'), dtype=np.float32)
 return X, Y


class AudioDataset(Dataset):
 """数据加载器"""
 def __init__(self, data_folder):
 self.data_folder = data_folder
 self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1)

 def __getitem__(self, item):
 # 返回一个音频数据
 X = self.X[item]
 Y = self.Y[item]

 return X, Y

 def __len__(self):
 return len(self.X)


train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)


for (i, wav_data) in enumerate(train_loader):
 X, Y = wav_data
 print(i, X.shape)
 # 0 torch.Size([64, 8192, 1])
 # 1 torch.Size([64, 8192, 1])
 # ...

我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,

参考

总结

到此这篇关于pytorch加载语音类自定义数据集的文章就介绍到这了,更多相关pytorch加载语音类自定义数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Django 浅谈根据配置生成SQL语句的问题
May 29 Python
Pycharm无法使用已经安装Selenium的解决方法
Oct 13 Python
对python判断ip是否可达的实例详解
Jan 31 Python
详解Python字典的操作
Mar 04 Python
python使用beautifulsoup4爬取酷狗音乐代码实例
Dec 04 Python
使用opencv将视频帧转成图片输出
Dec 10 Python
通过实例简单了解Python中yield的作用
Dec 11 Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 Python
Python for循环通过序列索引迭代过程解析
Feb 07 Python
使用matplotlib动态刷新指定曲线实例
Apr 23 Python
python实现测试工具(一)——命令行发送get请求
Oct 19 Python
Python 类,对象,数据分类,函数参数传递详解
Sep 25 Python
sublime3之内网安装python插件Anaconda的流程
Nov 10 #Python
python+excel接口自动化获取token并作为请求参数进行传参操作
Nov 10 #Python
python request 模块详细介绍
Nov 10 #Python
解决使用Pandas 读取超过65536行的Excel文件问题
Nov 10 #Python
python各种excel写入方式的速度对比
Nov 10 #Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
Nov 10 #Python
详解vscode实现远程linux服务器上Python开发
Nov 10 #Python
You might like
php性能优化分析工具XDebug 大型网站调试工具
2011/05/22 PHP
批量去除PHP文件中bom的PHP代码
2012/03/13 PHP
PHP容易忘记的知识点分享
2013/04/30 PHP
比较详细的关于javascript中void(0)的具体含义解释
2007/08/02 Javascript
Javascript hasOwnProperty 方法 & in 关键字
2008/11/26 Javascript
利用cookie记住背景颜色示例代码
2013/11/04 Javascript
小议JavaScript中Generator和Iterator的使用
2015/07/29 Javascript
js实现文字闪烁特效的方法
2015/12/17 Javascript
用原生JS对AJAX做简单封装的实例代码
2016/07/13 Javascript
jquery radio 动态控制选中失效问题的解决方法
2018/02/28 jQuery
webpack4.0打包优化策略整理小结
2018/03/30 Javascript
详解关于vue2.0工程发布上线操作步骤
2018/09/27 Javascript
jquery无缝图片轮播组件封装
2020/11/25 jQuery
springboot+vue+对接支付宝接口+二维码扫描支付功能(沙箱环境)
2020/10/15 Javascript
[45:52]完美世界DOTA2联赛PWL S3 Forest vs INK ICE 第二场 12.09
2020/12/12 DOTA
Python的Django框架中的URL配置与松耦合
2015/07/15 Python
详解Python中使用base64模块来处理base64编码的方法
2016/07/01 Python
如何高效使用Python字典的方法详解
2017/08/31 Python
Django渲染Markdown文章目录的方法示例
2019/01/02 Python
Django web框架使用url path name详解
2019/04/29 Python
一篇文章彻底搞懂Python中可迭代(Iterable)、迭代器(Iterator)与生成器(Generator)的概念
2019/05/13 Python
java中的控制结构(if,循环)详解
2019/06/26 Python
Python:slice与indices的用法
2019/11/25 Python
PyTorch和Keras计算模型参数的例子
2020/01/02 Python
通过 Python 和 OpenCV 实现目标数量监控
2020/01/05 Python
python numpy 矩阵堆叠实例
2020/01/17 Python
Python类的绑定方法和非绑定方法实例解析
2020/03/04 Python
HTML5实践-图片设置成灰度图
2012/11/12 HTML / CSS
html5使用canvas绘制文字特效
2014/12/15 HTML / CSS
加拿大约会网站:EliteSingles.ca
2018/01/12 全球购物
Feelunique德国官方网站:欧洲最大的在线美容零售商
2019/07/20 全球购物
必须要使用游标的SQL语句有那些
2012/05/07 面试题
思想汇报范文
2013/11/04 职场文书
我未来的职业规划范文
2014/01/11 职场文书
分公司总经理岗位职责
2014/07/30 职场文书
集体生日活动方案
2014/08/18 职场文书