pytorch加载语音类自定义数据集的方法教程


Posted in Python onNovember 10, 2020

前言

pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合

  • torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法
    •  __len()__ :返回数据集中数据的数量
    •   __getitem()__ :返回支持下标索引方式获取的一个数据
  • torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....

第一步

自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:

  • __len()__:读取数据,返回数据和标签
  • __getitem()__:返回数据集的长度
from torch.utils.data import Dataset


class AudioDataset(Dataset):
 def __init__(self, ...):
 """类的初始化"""
 pass

 def __getitem__(self, item):
 """每次怎么读数据,返回数据和标签"""
 return data, label

 def __len__(self):
 """返回整个数据集的长度"""
 return total

注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本

案例:

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

class AudioDataset(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
 self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
 max_audio_start = len(wb_wav) - self.dim
 audio_start = np.random.randint(0, max_audio_start)
 wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)

注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,

第二步

实例化 Dataset 对象

Dataset= AudioDataset("./p225", sr=16000)

如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)

for i, data in enumerate(train_set):
 wb_wav, filname = data
 print(i, wb_wav.shape, filname)

 if i == 3:
 break
 # 0 (8192,) ./p225\p225_001.wav
 # 1 (8192,) ./p225\p225_002.wav
 # 2 (8192,) ./p225\p225_003.wav
 # 3 (8192,) ./p225\p225_004.wav

第三步

如果想要通过batch读取数据,需要使用DataLoader进行包装

为何要使用DataLoader?

  • 深度学习的输入是mini_batch形式
  • 样本加载时候可能需要随机打乱顺序,shuffle操作
  • 样本加载需要采用多线程

pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

参数:

  • dataset:加载的数据集(Dataset对象)
  • batch_size:每个批次要加载多少个样本(默认值:1)
  • shuffle:每个epoch是否将数据打乱
  • sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
  • batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多线程
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

返回:数据加载器

案例:

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

我们来吃几个栗子消化一下:

栗子1

这个例子就是本文一直举例的,栗子1只是合并了一下而已

文件目录结构

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

import fnmatch
import os
import librosa
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class Aduio_DataLoader(Dataset):
 def __init__(self, data_folder, sr=16000, dimension=8192):
 self.data_folder = data_folder
 self.sr = sr
 self.dim = dimension

 # 获取音频名列表
 self.wav_list = []
 for root, dirnames, filenames in os.walk(data_folder):
  for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
  self.wav_list.append(os.path.join(root, filename))

 def __getitem__(self, item):
 # 读取一个音频文件,返回每个音频数据
 filename = self.wav_list[item]
 print(filename)
 wb_wav, _ = librosa.load(filename, sr=self.sr)

 # 取 帧
 if len(wb_wav) >= self.dim:
  max_audio_start = len(wb_wav) - self.dim
  audio_start = np.random.randint(0, max_audio_start)
  wb_wav = wb_wav[audio_start: audio_start + self.dim]
 else:
  wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")

 return wb_wav, filename

 def __len__(self):
 # 音频文件的总数
 return len(self.wav_list)


train_set = Aduio_DataLoader("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)


for (i, data) in enumerate(train_loader):
 wav_data, wav_name = data
 print(wav_data.shape) # torch.Size([8, 8192])
 print(i, wav_name)
 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

注意事项:

  1. 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
  2. 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意

栗子2

相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。

我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。

具体实现代码:

第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:

第二步:通过Dataset从h5格式文件中读取数据

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py

def load_h5(h5_path):
 # load training data
 with h5py.File(h5_path, 'r') as hf:
 print('List of arrays in input file:', hf.keys())
 X = np.array(hf.get('data'), dtype=np.float32)
 Y = np.array(hf.get('label'), dtype=np.float32)
 return X, Y


class AudioDataset(Dataset):
 """数据加载器"""
 def __init__(self, data_folder):
 self.data_folder = data_folder
 self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1)

 def __getitem__(self, item):
 # 返回一个音频数据
 X = self.X[item]
 Y = self.Y[item]

 return X, Y

 def __len__(self):
 return len(self.X)


train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)


for (i, wav_data) in enumerate(train_loader):
 X, Y = wav_data
 print(i, X.shape)
 # 0 torch.Size([64, 8192, 1])
 # 1 torch.Size([64, 8192, 1])
 # ...

我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,

参考

总结

到此这篇关于pytorch加载语音类自定义数据集的文章就介绍到这了,更多相关pytorch加载语音类自定义数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python读写配置文件的方法
Jun 03 Python
Python上下文管理器和with块详解
Sep 09 Python
Python OpenCV 调用摄像头并截图保存功能的实现代码
Jul 02 Python
对Python函数设计规范详解
Jul 19 Python
Django REST framework 如何实现内置访问频率控制
Jul 23 Python
python PyAutoGUI 模拟鼠标键盘操作和截屏功能
Aug 04 Python
python yield关键词案例测试
Oct 15 Python
python通过nmap扫描在线设备并尝试AAA登录(实例代码)
Dec 30 Python
pycharm的python_stubs问题
Apr 08 Python
TensorFlow使用Graph的基本操作的实现
Apr 22 Python
Python定义一个函数的方法
Jun 15 Python
Ubuntu 20.04安装Pycharm2020.2及锁定到任务栏的问题(小白级操作)
Oct 29 Python
sublime3之内网安装python插件Anaconda的流程
Nov 10 #Python
python+excel接口自动化获取token并作为请求参数进行传参操作
Nov 10 #Python
python request 模块详细介绍
Nov 10 #Python
解决使用Pandas 读取超过65536行的Excel文件问题
Nov 10 #Python
python各种excel写入方式的速度对比
Nov 10 #Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
Nov 10 #Python
详解vscode实现远程linux服务器上Python开发
Nov 10 #Python
You might like
php下实现农历日历的代码
2007/03/07 PHP
php class中self,parent,this的区别以及实例介绍
2013/04/24 PHP
PHP利用APC模块实现文件上传进度条的方法
2015/01/26 PHP
PHP文件缓存类示例分享
2015/01/30 PHP
phpstorm激活码2020附使用详细教程
2020/09/25 PHP
javascript 闭包
2011/09/15 Javascript
JavaScript异步编程:异步数据收集的具体方法
2013/08/19 Javascript
js对table的td进行相同内容合并示例详解
2013/12/27 Javascript
jquery制作select列表双向选择示例代码
2014/09/02 Javascript
jQuery Ajax和getJSON获取后台普通json数据和层级json数据用法分析
2016/06/08 Javascript
Javascript 实现全屏滚动实例代码
2016/12/31 Javascript
微信小程序之发送短信倒计时功能
2017/08/30 Javascript
vue实现长图垂直居上 vue实现短图垂直居中
2017/10/18 Javascript
Tensorflow的可视化工具Tensorboard的初步使用详解
2018/02/11 Python
python3+PyQt5图形项的自定义和交互 python3实现page Designer应用程序
2020/07/20 Python
基于Python的PIL库学习详解
2019/05/10 Python
django echarts饼图数据动态加载的实例
2019/08/12 Python
详解python 中in 的 用法
2019/12/12 Python
tensorflow入门:TFRecordDataset变长数据的batch读取详解
2020/01/20 Python
tensorflow将图片保存为tfrecord和tfrecord的读取方式
2020/02/17 Python
基于keras中的回调函数用法说明
2020/06/17 Python
Steiff台湾官网:德国金耳釦泰迪熊
2019/12/26 全球购物
State Cashmere官网:半零售价可持续蒙古羊绒
2020/02/26 全球购物
英国最受欢迎的母婴精品品牌:JoJo Maman BéBé
2021/02/17 全球购物
财务助理岗位职责
2013/11/10 职场文书
通用求职信范文模板分享
2013/12/27 职场文书
会计专业个人求职信范文
2014/01/08 职场文书
活动邀请函范文
2014/01/19 职场文书
《会变的花树叶》教学反思
2014/02/10 职场文书
美术教师岗位职责
2014/03/18 职场文书
秘书英文求职信
2014/04/16 职场文书
文明家庭先进事迹材料
2014/05/14 职场文书
财务工作失职检讨书
2014/11/21 职场文书
小学2016年第十八届推普周活动总结
2016/04/05 职场文书
MybatisPlus代码生成器的使用方法详解
2021/06/13 Java/Android
使用Django框架创建项目
2022/06/10 Python