In [1]:
import torch
import torch.utils.data as data
import torchaudio
import librosa
import numpy as np
import random
import os
import glob
In [11]:
class VariableLengthDataset(data.Dataset):
    """Yield fixed-size snippets from a list of audio files of varying length.

    Each file contributes ``max(len(sig) // snippet_length, 1)`` items; a file
    shorter than ``snippet_length`` yields a single, shorter snippet. The
    dataset grows while it is read: when a file's first snippet is served,
    ``self.acc`` is increased by the number of extra snippets that file
    provides, and ``__len__`` returns ``len(manifest) + acc``.

    Because of this bookkeeping, indices must be requested in order
    (0, 1, 2, ...) by a single consumer; random access or shuffling maps
    indices to the wrong file. Plain iteration (``for x, y in ds``) works
    because Python keeps calling ``__getitem__`` until the ``IndexError``
    raised once the manifest is exhausted.
    NOTE(review): samplers that read ``len(ds)`` once up front presumably do
    not see snippets added during that pass -- confirm before relying on it.

    Snippets are cut along dim 0 of the loaded signal (the librosa path
    reshapes to ``(samples, 1)``). Every loaded signal stays cached in
    ``self.data`` for the lifetime of the dataset.

    Args:
        manifest: sequence of audio file paths.
        snippet_length: samples per snippet.
        get_sequentially: if True, a file's snippets are consecutive,
            non-overlapping windows starting at sample 0; otherwise each
            snippet starts at a random offset.
        ret_np: return snippets as numpy arrays instead of tensors.
        use_librosa: load with ``librosa`` instead of ``torchaudio``.

    Each item is ``(snippet, label)``; ``label`` is currently always ``0``.
    """

    def __init__(self, manifest, snippet_length=24000, get_sequentially=False, ret_np=False, use_librosa=False):
        self.manifest = manifest
        self.snippet_length = snippet_length
        self.get_sequentially = get_sequentially
        self.use_librosa = use_librosa
        self.ret_np = ret_np
        self.acc = 0              # extra snippets discovered so far in this pass
        self.snippet_counter = 0  # snippets already served from the current file
        self.audio_idx = 0        # manifest position of the current file
        self.st = 0               # start sample of the most recent snippet
        self.data = {}            # cache: path -> (sig, sr)

    def _load_audio(self, apath):
        """Return ``(sig, sr)`` for ``apath``, loading and caching it on first use."""
        if apath not in self.data:
            if self.use_librosa:
                sig, sr = librosa.core.load(apath, sr=None)
                sig = torch.from_numpy(sig).unsqueeze(1).float()
            else:
                # NOTE(review): ``normalization=`` and a dim-0 sample axis match older
                # torchaudio releases; newer ones use ``normalize=`` and return
                # (channels, samples) -- confirm against the installed version.
                sig, sr = torchaudio.load(apath, normalization=True)
            self.data[apath] = (sig, sr)
        return self.data[apath]

    def __getitem__(self, index):
        """Return the next ``(snippet, label)``; ``index`` must advance by one per call."""
        if self.snippet_counter == 0:
            # New file: remove the offset introduced by earlier files' extra snippets.
            self.audio_idx = index - self.acc
            # An out-of-range position raises IndexError here, which ends iteration.
            apath = self.manifest[self.audio_idx]
            sig, _ = self._load_audio(apath)
            num_snippets = int(sig.size(0) // self.snippet_length)
            # Grow the dataset by the additional snippets this file provides.
            self.acc += max(num_snippets - 1, 0)
        else:
            apath = self.manifest[self.audio_idx]
            sig, _ = self.data[apath]
            num_snippets = int(sig.size(0) // self.snippet_length)

        n_samples = sig.size(0)
        if self.get_sequentially:
            # Window k of a file starts at k * snippet_length: the first snippet
            # begins at sample 0 and no window runs past the end of the file.
            self.st = self.snippet_counter * self.snippet_length
        elif n_samples <= self.snippet_length:
            # Too short for a random crop: return the whole signal.
            self.st = 0
        else:
            # +1 so the last full window is also a possible start.
            self.st = random.randrange(n_samples - self.snippet_length + 1)
        ret_sig = sig[self.st:(self.st + self.snippet_length)]
        if self.ret_np:
            ret_sig = ret_sig.numpy()

        # update counter for current audio file
        self.snippet_counter += 1

        # Label creation: the speaker would be the parent directory name,
        # os.path.dirname(apath).rsplit("/", 1)[-1]; a dummy label is used for now.
        spkr = 0

        # Last snippet of this file served: the next call starts a new file.
        if self.snippet_counter >= num_snippets:
            self.snippet_counter = 0
            self.st = 0
        return ret_sig, spkr

    def __len__(self):
        """Number of items known so far: one per file plus extra snippets discovered."""
        return len(self.manifest) + self.acc

    def reset_acc(self):
        """Prepare for a new pass over the manifest.

        Besides clearing ``acc``, this clears the per-file cursor so a pass that
        was interrupted mid-file (e.g. ``break`` out of a loop) does not make the
        next pass resume inside a stale file.
        """
        self.acc = 0
        self.snippet_counter = 0
        self.audio_idx = 0
        self.st = 0
class FixedLengthDataset(data.Dataset):
    """Yield one snippet of at most ``snippet_length`` samples per audio file.

    Files longer than ``snippet_length`` are cropped at a uniformly random
    offset; files of equal or shorter length are returned whole. Audio is read
    from disk on every access (``self.data`` is never filled).

    Args:
        manifest: sequence of audio file paths.
        snippet_length: maximum samples per snippet.
        ret_np: return snippets as numpy arrays instead of tensors.
        use_librosa: load with ``librosa`` instead of ``torchaudio``.
        transforms: optional callable applied to the snippet tensor before any
            numpy conversion.

    Each item is ``(snippet, label)``; ``label`` is currently always ``0``.
    """

    def __init__(self, manifest, snippet_length=24000, ret_np=False, use_librosa=False, transforms=None):
        self.manifest = manifest
        self.snippet_length = snippet_length
        self.use_librosa = use_librosa
        self.ret_np = ret_np
        self.transforms = transforms
        self.num_snippets = 1  # items per file
        self.acc = 0           # kept for interface parity; never incremented here
        self.audio_idx = 0     # manifest position of the most recent item
        self.st = 0            # start sample of the most recent random crop
        self.data = {}         # unused: this dataset does not cache audio

    def _load_audio(self, apath):
        """Load ``apath`` from disk and return ``(sig, sr)``; nothing is cached."""
        if self.use_librosa:
            sig, sr = librosa.core.load(apath, sr=None)
            return torch.from_numpy(sig).unsqueeze(1).float(), sr
        # NOTE(review): ``normalization=`` and a dim-0 sample axis match older
        # torchaudio releases; newer ones use ``normalize=`` and return
        # (channels, samples) -- confirm against the installed version.
        sig, sr = torchaudio.load(apath, normalization=True)
        return sig, sr

    def __getitem__(self, index):
        """Return ``(snippet, label)`` for the file at ``index // num_snippets``."""
        # With num_snippets == 1 this is simply ``index``.
        self.audio_idx = index // self.num_snippets
        apath = self.manifest[self.audio_idx]  # IndexError here ends plain iteration
        sig, _ = self._load_audio(apath)

        # create snippet
        n_samples = sig.size(0)
        if n_samples <= self.snippet_length:
            ret_sig = sig
        else:
            # +1 so the last full window is also a possible start.
            self.st = random.randrange(n_samples - self.snippet_length + 1)
            ret_sig = sig[self.st:(self.st + self.snippet_length)]
        if self.transforms is not None:
            ret_sig = self.transforms(ret_sig)
        if self.ret_np:
            ret_sig = ret_sig.numpy()

        # Label creation: the speaker would be the parent directory name,
        # os.path.dirname(apath).rsplit("/", 1)[-1]; a dummy label is used for now.
        spkr = 0
        return ret_sig, spkr

    def __len__(self):
        """Total items: ``num_snippets`` per manifest entry."""
        return len(self.manifest) * self.num_snippets

    def reset_acc(self):
        """Counterpart of ``VariableLengthDataset.reset_acc`` so both datasets
        expose the same per-pass reset; ``acc`` never grows here."""
        self.acc = 0
def run_dataset(dataset=None, num_epochs=1):
    """Iterate a dataset end to end, for timing and smoke-testing.

    Each epoch materializes every ``(x, label)`` pair by plain iteration.
    For datasets without ``__iter__``, Python falls back to calling
    ``__getitem__`` with 0, 1, 2, ... until ``IndexError``. The epoch number
    and item count are then printed, and ``dataset.reset_acc()`` is called if
    the dataset defines it.

    Args:
        dataset: dataset to iterate. Defaults to the notebook-global ``ds``,
            looked up at call time, so the original zero-argument call still works.
        num_epochs: number of passes over the dataset.

    Returns:
        None. Nothing is returned, so ``%time run_dataset()`` does not dump
        the data as cell output.
    """
    if dataset is None:
        dataset = ds
    for epoch in range(num_epochs):
        all_data = [(x, label) for x, label in dataset]
        print(epoch, len(all_data))
        # Skip datasets that have no reset_acc instead of swallowing every
        # exception (including real failures inside reset_acc) with a bare except.
        reset = getattr(dataset, "reset_acc", None)
        if reset is not None:
            reset()
In [3]:
datadir = "/home/david/Programming/tests/pcsnpny-20150204-mkj"
audio_manifest = [a for a in glob.glob(datadir+"/**/*.wav", recursive=True)]
ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=False)
%time run_dataset()
ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=True)
%time run_dataset()
In [4]:
# Smoke-test batching: random snippets from VariableLengthDataset through a DataLoader.
# NOTE(review): acc only grows while the dataset is read, so len(ds) here equals
# len(audio_manifest) -- confirm the loader is expected to see only that many items.
ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=False, ret_np=False, use_librosa=False)
dl = data.DataLoader(ds, batch_size=10)
print(len(ds), len(dl))
first_batch = next(iter(dl), None)
if first_batch is not None:
    mb, tgts = first_batch
    print(mb.size(), tgts.size())
In [13]:
datadir = "/home/david/Programming/tests/pcsnpny-20150204-mkj"
audio_manifest = [a for a in glob.glob(datadir+"/**/*.wav", recursive=True)]
ds = FixedLengthDataset(audio_manifest, 12000, ret_np=True, use_librosa=True)
%time run_dataset()