In [1]:
import torch
import torch.utils.data as data
import torchaudio
import librosa
import numpy as np
import random
import os
import glob

In [11]:
class VariableLengthDataset(data.Dataset):
    def __init__(self, manifest, snippet_length=24000, get_sequentially=False, ret_np=False, use_librosa=False):
        self.manifest = manifest
        self.snippet_length = snippet_length
        self.get_sequentially = get_sequentially
        self.use_librosa = use_librosa
        self.ret_np = ret_np
        self.acc = 0
        self.snippet_counter = 0
        self.audio_idx = 0
        self.st = 0
        self.data = {}
    def __getitem__(self, index):
        # load audio data from file or cache
        if self.snippet_counter == 0:
            self.audio_idx = index - self.acc
            apath = self.manifest[self.audio_idx]
            if apath not in self.data:
                if self.use_librosa:
                    sig, sr = librosa.core.load(apath, sr=None)
                    sig = torch.from_numpy(sig).unsqueeze(1).float()
                else:
                    sig, sr = torchaudio.load(apath, normalization=True)
                self.data[apath] = (sig, sr)
            else:
                sig, sr = self.data[apath]

            # increase iterations based on length of audio
            num_snippets = int(sig.size(0) // self.snippet_length)
            self.acc += max(num_snippets-1,0)
        else:
            apath = self.manifest[self.audio_idx]
            sig, sr = self.data[apath]
            num_snippets = int(sig.size(0) // self.snippet_length)

        # create snippet
        if self.get_sequentially:
            # advance the offset only after the first snippet of a file,
            # so the first snippet starts at sample 0
            if self.snippet_counter > 0:
                self.st += self.snippet_length
        else:
            # +1 so the last full snippet position is reachable
            self.st = random.randrange(int(sig.size(0) - self.snippet_length) + 1)
        ret_sig = sig[self.st:(self.st + self.snippet_length)]
        if self.ret_np:
            ret_sig = ret_sig.numpy()

        # update counter for current audio file
        self.snippet_counter += 1

        # label creation
        #spkr = os.path.dirname(apath).rsplit("/", 1)[-1]
        spkr = 0 # just using a dummy label now.

        # check for reset
        if self.snippet_counter >= num_snippets:
            self.snippet_counter = 0
            self.st = 0

        return ret_sig, spkr

    def __len__(self):
        return len(self.manifest) + self.acc

    def reset_acc(self):
        self.acc = 0
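
# Note: VariableLengthDataset mutates per-call state (acc, snippet_counter, st) inside
# __getitem__, so it assumes sequential, single-process iteration (a DataLoader with
# shuffle=False and num_workers=0) and a reset_acc() call between epochs.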

class FixedLengthDataset(data.Dataset):
    def __init__(self, manifest, snippet_length=24000, ret_np=False, use_librosa=False):
        self.manifest = manifest
        self.snippet_length = snippet_length
        self.use_librosa = use_librosa
        self.ret_np = ret_np
        self.num_snippets = 1
        self.acc = 0
        self.audio_idx = 0
        self.st = 0
        self.data = {}
    def __getitem__(self, index):
        # load audio data from file or cache
        self.audio_idx = index if self.num_snippets == 1 else index // self.num_snippets
        apath = self.manifest[self.audio_idx]
        if self.use_librosa:
            sig, sr = librosa.core.load(apath, sr=None)
            sig = torch.from_numpy(sig).unsqueeze(1).float()
        else:
            sig, sr = torchaudio.load(apath, normalization=True)

        # create snippet
        if sig.size(0) < self.snippet_length:
            ret_sig = sig
        else:
            # +1 so a signal exactly snippet_length samples long is still valid
            self.st = random.randrange(int(sig.size(0) - self.snippet_length) + 1)
            ret_sig = sig[self.st:(self.st+self.snippet_length)]
        if self.ret_np:
            ret_sig = ret_sig.numpy()

        # label creation
        #spkr = os.path.dirname(apath).rsplit("/", 1)[-1]
        spkr = 0 # just using a dummy label now.

        return ret_sig, spkr

    def __len__(self):
        return len(self.manifest) * self.num_snippets

def run_dataset():
    for epoch in range(1):
        all_data = [(x, label) for x, label in ds]
        print(epoch, len(all_data))
        try:
            ds.reset_acc()
        except AttributeError:
            # FixedLengthDataset has no accumulator to reset
            pass
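
Both datasets can return snippets shorter than snippet_length (FixedLengthDataset hands back the whole signal when a file is too short), so batching mixed lengths needs a padding collate. Below is a minimal sketch, assuming ret_np=False so signals stay [num_samples, 1] tensors; pad_collate is just an illustrative name, not part of the classes above.

In [ ]:
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch: list of (signal, label) pairs, each signal shaped [num_samples, 1]
    sigs, labels = zip(*batch)
    lengths = torch.tensor([s.size(0) for s in sigs])
    # pad along the time dimension so every signal matches the longest in the batch
    padded = pad_sequence(list(sigs), batch_first=True)  # [batch, max_len, 1]
    return padded, torch.tensor(labels), lengths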

In [3]:
datadir = "/home/david/Programming/tests/pcsnpny-20150204-mkj"
audio_manifest = glob.glob(datadir + "/**/*.wav", recursive=True)

ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=False)
%time run_dataset()
ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=True)
%time run_dataset()


0
CPU times: user 91.1 ms, sys: 8.62 ms, total: 99.7 ms
Wall time: 291 ms
0
CPU times: user 79 ms, sys: 6.72 ms, total: 85.7 ms
Wall time: 69.8 ms

In [4]:
ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=False, ret_np=False, use_librosa=False)
dl = data.DataLoader(ds, batch_size=10)
print(len(ds), len(dl))
for mb, tgts in dl:
    print(mb.size(), tgts.size())
    break


10 1
torch.Size([10, 12000, 1]) torch.Size([10])

In [13]:
datadir = "/home/david/Programming/tests/pcsnpny-20150204-mkj"
audio_manifest = glob.glob(datadir + "/**/*.wav", recursive=True)
ds = FixedLengthDataset(audio_manifest, 12000, ret_np=True, use_librosa=True)
%time run_dataset()


0 10
CPU times: user 18.1 ms, sys: 0 ns, total: 18.1 ms
Wall time: 18.4 ms
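
As a usage sketch, the pad_collate helper above can be passed to a DataLoader so that FixedLengthDataset batches survive files shorter than the snippet length (again assuming ret_np=False):

In [ ]:
ds = FixedLengthDataset(audio_manifest, 12000, ret_np=False, use_librosa=False)
dl = data.DataLoader(ds, batch_size=4, collate_fn=pad_collate)
for mb, tgts, lengths in dl:
    print(mb.size(), tgts.size(), lengths)
    break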