pytorch加载语音类自定义数据集的方法教程


Posted in Python onNovember 10, 2020

前言

pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合

  • torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法
    •  __len()__ :返回数据集中数据的数量
    •   __getitem()__ :返回支持下标索引方式获取的一个数据
  • torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....

第一步

自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:

  • __len()__:读取数据,返回数据和标签
  • __getitem()__:返回数据集的长度
from torch.utils.data import Dataset


class AudioDataset(Dataset):
 def __init__(self, ...):
 """类的初始化"""
 pass

 def __getitem__(self, item):
 """每次怎么读数据,返回数据和标签"""
 return data, label

 def __len__(self):
 """返回整个数据集的长度"""
 return total

注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本

案例:

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

class AudioDataset(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
 self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
 max_audio_start = len(wb_wav) - self.dim
 audio_start = np.random.randint(0, max_audio_start)
 wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)

注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,

第二步

实例化 Dataset 对象

Dataset= AudioDataset("./p225", sr=16000)

如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)

for i, data in enumerate(train_set):
 wb_wav, filname = data
 print(i, wb_wav.shape, filname)

 if i == 3:
 break
 # 0 (8192,) ./p225\p225_001.wav
 # 1 (8192,) ./p225\p225_002.wav
 # 2 (8192,) ./p225\p225_003.wav
 # 3 (8192,) ./p225\p225_004.wav

第三步

如果想要通过batch读取数据,需要使用DataLoader进行包装

为何要使用DataLoader?

  • 深度学习的输入是mini_batch形式
  • 样本加载时候可能需要随机打乱顺序,shuffle操作
  • 样本加载需要采用多线程

pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

参数:

  • dataset:加载的数据集(Dataset对象)
  • batch_size:每个批次要加载多少个样本(默认值:1)
  • shuffle:每个epoch是否将数据打乱
  • sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
  • batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多线程
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

返回:数据加载器

案例:

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

我们来吃几个栗子消化一下:

栗子1

这个例子就是本文一直举例的,栗子1只是合并了一下而已

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

import fnmatch
import os
import librosa
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class Aduio_DataLoader(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
  for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
  self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 print(filename)
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
  max_audio_start = len(wb_wav) - self.dim
  audio_start = np.random.randint(0, max_audio_start)
  wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
  wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)


train_set = Aduio_DataLoader("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)


for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

注意事项:

  1. 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
  2. 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意

栗子2

相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。

我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。

具体实现代码:

第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:

第二步:通过Dataset从h5格式文件中读取数据

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py

def load_h5(h5_path):
 # load training data
 with h5py.File(h5_path, 'r') as hf:
 print('List of arrays in input file:', hf.keys())
 X = np.array(hf.get('data'), dtype=np.float32)
 Y = np.array(hf.get('label'), dtype=np.float32)
 return X, Y


class AudioDataset(Dataset):
 """数据加载器"""
 def __init__(self, data_folder):
 self.data_folder = data_folder
 self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1)

 def __getitem__(self, item):
 # 返回一个音频数据
 X = self.X[item]
 Y = self.Y[item]

 return X, Y

 def __len__(self):
 return len(self.X)


train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)


for (i, wav_data) in enumerate(train_loader):
 X, Y = wav_data
 print(i, X.shape)
 # 0 torch.Size([64, 8192, 1])
 # 1 torch.Size([64, 8192, 1])
 # ...

我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,

参考

总结

到此这篇关于pytorch加载语音类自定义数据集的文章就介绍到这了,更多相关pytorch加载语音类自定义数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python装饰器decorator用法实例
Nov 10 Python
关于Python元祖,列表,字典,集合的比较
Jan 06 Python
Python+OpenCV人脸检测原理及示例详解
Oct 19 Python
取numpy数组的某几行某几列方法
Apr 03 Python
基于DataFrame改变列类型的方法
Jul 25 Python
Python读取excel指定列生成指定sql脚本的方法
Nov 28 Python
Python跳出多重循环的方法示例
Jul 03 Python
python django下载大的csv文件实现方法分析
Jul 19 Python
如何使用python操作vmware
Jul 27 Python
Django使用Channels实现WebSocket的方法
Jul 28 Python
基于Python+Appium实现京东双十一自动领金币功能
Oct 31 Python
Python使用monkey.patch_all()解决协程阻塞问题
Apr 15 Python
sublime3之内网安装python插件Anaconda的流程
Nov 10 #Python
python+excel接口自动化获取token并作为请求参数进行传参操作
Nov 10 #Python
python request 模块详细介绍
Nov 10 #Python
解决使用Pandas 读取超过65536行的Excel文件问题
Nov 10 #Python
python各种excel写入方式的速度对比
Nov 10 #Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
Nov 10 #Python
详解vscode实现远程linux服务器上Python开发
Nov 10 #Python
You might like
php 调用远程url的六种方法小结
2009/11/02 PHP
PHP学习之数组的定义和填充
2011/04/17 PHP
PHP IE中下载附件问题解决方法
2014/01/07 PHP
[原创]ThinkPHP让../Public在模板不解析(直接输出)的方法
2015/10/09 PHP
PHP Ajax跨域问题解决方案代码实例
2020/08/01 PHP
你的编程语言可以这样做吗?
2006/09/07 Javascript
JavaScript 特殊字符
2007/04/05 Javascript
JAVASCRIPT  THIS详解 面向对象
2009/03/25 Javascript
js 自定义的联动下拉框
2010/02/07 Javascript
jquery中动态效果小结
2010/12/16 Javascript
让你的博客飘雪花超出屏幕依然看得见
2013/01/04 Javascript
体验js中splice()的强大(插入、删除或替换数组的元素)
2013/01/16 Javascript
jQuery循环滚动新闻列表示例代码
2014/06/17 Javascript
js获取对象、数组的实际长度,元素实际个数的实现代码
2016/06/08 Javascript
微信小程序 数据绑定详解及实例
2016/10/25 Javascript
微信端开发--登录小程序步骤
2017/01/11 Javascript
JavaScript实现实时更新系统时间的实例代码
2017/04/04 Javascript
vue2.0 自定义日期时间过滤器
2017/06/07 Javascript
BootStrap点击保存后实现模态框自动关闭的思路(模态框)
2017/09/26 Javascript
分析JS单线程异步io回调的特性
2017/12/01 Javascript
通过vue-cli3构建一个SSR应用程序的方法
2018/09/13 Javascript
vue中的router-view组件的使用教程
2018/10/23 Javascript
监听angularJs列表数据是否渲染完毕的方法示例
2018/11/07 Javascript
详解jQuery设置内容和属性
2019/04/11 jQuery
解决Vue项目中tff报错的问题
2020/10/21 Javascript
详解Python中映射类型(字典)操作符的概念和使用
2015/08/19 Python
Python 实现还原已撤回的微信消息
2019/06/18 Python
python调用win32接口进行截图的示例
2020/11/11 Python
python3中for循环踩过的坑记录
2020/12/14 Python
HTML5实现可缩放时钟代码
2017/08/28 HTML / CSS
应用化学专业职业生涯规划书
2014/01/22 职场文书
出售房屋委托书范本
2014/09/24 职场文书
教师党员自我评议不足范文
2014/10/19 职场文书
2014年教务处工作总结
2014/12/03 职场文书
Python - 10行代码集2000张美女图
2021/05/23 Python
详解PyTorch模型保存与加载
2022/04/28 Python