In [ ]:
import os

import numpy as np
import random

import datetime

In [ ]:
# The three LibriVox chapters (01, 02, 03) used as training audio.
audio_filenames = [
    './librivox/guidetomen_%02d_rowland_64kb.mp3' % chapter
    for chapter in (1, 2, 3)
]
audio_filenames

In [ ]:
import librosa
librosa.__version__  # '0.5.1'

In [ ]:
sample_rate = 24000  # Hz; every input file is resampled to this rate on load

fft_step   = 12.5/1000.  # 12.5ms hop between STFT frames
fft_window = 50.0/1000.  # 50ms analysis window

n_fft = 512*4  # FFT size in samples; must be >= win_length

# round() rather than a bare int(): int() truncates, so a float product that
#   lands at e.g. 299.9999... would silently become 299 samples.
hop_length = int(round(fft_step*sample_rate))
win_length = int(round(fft_window*sample_rate))
assert win_length <= n_fft, "STFT window must fit inside the FFT frame"

n_mels = 80
fmin = 125  # Hz, lower edge of the lowest mel band
#fmax = ~8000

#np.exp(-7.0), np.log(spectra_abs_min)  # "Audio tests" suggest a min log of -4.605 (-6 confirmed fine)
spectra_abs_min = 0.01  # Floor applied before log(mel); From Google paper, seems justified

win_length, hop_length

In [ ]:
# And for the training windowing :
mel_samples  = 1024   # mel frames per training block
batch_size   = 8

epochs = 10

seed = 10

# Seed every RNG the notebook draws from: `random` (file-order shuffles),
#   numpy (within-file shuffles) and torch (weight init, codebook randperm,
#   DataLoader shuffling) - torch was previously left unseeded.
import torch
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [ ]:
# pip install https://github.com/telegraphic/hickle/archive/dev.zip
import hickle as hkl

def audio_to_melspectrafile(audio_filepath, regenerate=False):
    """Convert an audio file into a cached hickle file of mel spectra.

    The cache is written next to the audio as '<name>.melspectra.hkl' and holds
    a dict with keys 'mels' (mel power spectrogram), 'mel_log' (log of the mel
    power, floored at `spectra_abs_min`) and 'spectra_complex' (the raw STFT).

    Parameters
    ----------
    audio_filepath : str
        Path to the audio file (the notebook uses '.mp3').
    regenerate : bool
        Recompute even if the cache file already exists.
        NOTE(review): caches written before the sr/normalisation fixes below
        should be rebuilt with regenerate=True.

    Returns
    -------
    str
        Path of the (possibly pre-existing) '.melspectra.hkl' file.
    """
    print("convert_wavs_to_spectra_learnable_records(%s)" % (audio_filepath,))
    # splitext rather than str.replace('.mp3', ...): replace would leave a
    #   non-mp3 path unchanged (so the dump below would overwrite the audio)
    #   and would also rewrite '.mp3' appearing in directory names.
    melspectra_filepath = os.path.splitext(audio_filepath)[0] + '.melspectra.hkl'
    if os.path.isfile(melspectra_filepath) and not regenerate:
        print("  Already present")
        return melspectra_filepath

    samples, _sample_rate = librosa.core.load(audio_filepath, sr=sample_rate)

    # Peak-normalise into [-1, +1]. Dividing by max(samples) is wrong when the
    #   largest excursion is negative (values would go below -1) and breaks when
    #   max(samples) <= 0, so divide by the absolute peak (skip silent input).
    peak = np.max(np.abs(samples)) if samples.size else 0.
    if peak > 0:
        samples = samples/peak

    spectra_complex = librosa.stft(samples, n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length, window='hann', )

    power_spectra = np.abs(spectra_complex)**2
    # sr must be given explicitly: melspectrogram() otherwise assumes its
    #   default rate (22050Hz) and places the mel filter edges at the wrong
    #   FFT bins for audio loaded at `sample_rate`.
    melspectra = librosa.feature.melspectrogram(S=power_spectra, sr=sample_rate,
                                                n_mels=n_mels, fmin=fmin)

    # Floor before the log so near-silent bins don't produce -inf / huge negatives
    mel_log = np.log( np.maximum(spectra_abs_min, np.abs(melspectra) ))

    # Shape of batches will be (Batch, MelsChannel, TimeStep) for PyTorch - no need for Transpose
    data = dict(
        mels = melspectra,
        mel_log = mel_log,
        spectra_complex = spectra_complex,
        #spectra_real = spectra_complex.real,
        #spectra_imag = spectra_complex.imag,
    )

    hkl.dump(data, melspectra_filepath, mode='w', compression='gzip')
    return melspectra_filepath

In [ ]:
# Build (or reuse the cached) mel-spectra file for every audio file, in order
mel_filenames = list(map(audio_to_melspectrafile, audio_filenames))

In [ ]:
# Don't see a clean way of shuffling without having loaded all the input first...

#class DatasetFromMelspectraFile(torch.utils.data.Dataset):
#    def __init__(self, melspectra_filepath):
#        super(DatasetFromMelspectraFile, self).__init__()
#        
#        data = hkl.load(melspectra_filepath)
#        self.mels = data['mels']
#
#    def __getitem__(self, index):
#        offset = index*mel_samples 
#        a = self.mels[:, offset:offset+mel_samples]
#        return a,a  # This is a VAE situation
#
#    def __len__(self):  
#        return self.mels.shape[1]//mel_samples
#    
#class DatasetFromFiles(torch.utils.data.Dataset):
#    def __init__(self, filepath_arr, length_arr):
#        super(DatasetFromFiles, self).__init__()
#        self.filepaths = filepath_arr
#        self.file_index, self.item_index = -1,-1
#        self.d = None
#        
#    def __getitem__(self, index):
#        self.item_index+=1
#        if self.d is None or self.item_index >= len(self.d):
#            self.file_index+=1
#            self.d = DatasetFromMelspectraFile(self.filepaths[self.file_index])
#            self.item_index=0
#        return d[self.item_index]
#
#    def __len__(self):  
#        #return len(self.filepaths)
#        return -1 # DUNNO

In [ ]:
# This approach allows us to load the files into memory only as needed - 
#   But may not be necessary for our purposes, since the data is actually pretty small

def yield_batches_from(melspectra_filepath, bs=batch_size, shuffle=False):
    """Yield (input, target) batches of mel windows from one cached file.

    The file's 'mels' array (channels x frames) is cut into consecutive,
    non-overlapping windows of `mel_samples` frames. Each yielded batch has
    shape (k, channels, mel_samples) with k == bs, except the final batch,
    which holds the remaining k < bs windows. Input and target are the same
    array (autoencoder setup).

    Parameters
    ----------
    melspectra_filepath : str
        Hickle file written by audio_to_melspectrafile().
    bs : int
        Maximum windows per batch.
    shuffle : bool
        Visit the windows of this file in random (numpy RNG) order.
    """
    data = hkl.load(melspectra_filepath)
    mels = data['mels']
    n_channels = mels.shape[0]
    # +1 so that a window ending exactly at the last frame is included too
    offsets = np.arange(0, mels.shape[1]-mel_samples+1, mel_samples)
    print("Batches from file : ", melspectra_filepath, mels.shape, offsets.shape)
    if shuffle:
        np.random.shuffle(offsets)  # in-place
    for batch_start in range(0, offsets.shape[0], bs):
        # Offsets belonging to THIS batch (previously offsets[i] was used,
        #   so every batch repeated the first `bs` windows).
        batch_offsets = offsets[batch_start:batch_start+bs]
        # Fresh array per batch: reusing one buffer would silently change
        #   batches that a caller has already kept a reference to.
        batch_x = np.zeros( shape=(len(batch_offsets), n_channels, mel_samples) )
        for i, offset in enumerate(batch_offsets):
            batch_x[i, :, :] = mels[:, offset:offset+mel_samples]
        yield batch_x, batch_x # input -> target
    # Stop

def yield_batches_from_files(filepaths, bs=batch_size, shuffle=False, shuffle_within=False):
    """Chain the batches of several mel files into a single generator.

    shuffle        : visit the files in a random order; a shuffled copy is
                     made, so the caller's list is left untouched
    shuffle_within : shuffle the window order inside each file
    """
    ordered_paths = random.sample(filepaths, len(filepaths)) if shuffle else filepaths
    for path in ordered_paths:
        yield from yield_batches_from(path, bs=bs, shuffle=shuffle_within)

# This is how this code looks when used :
#for epoch in range(epochs):
#    t0 = datetime.datetime.now()
#    train_batcher = yield_batches_from_files(mel_filenames, bs=batch_size, shuffle=True, shuffle_within=True)
#    for batch_idx, batch in enumerate(train_batcher):
#        input, target = batch
#        ...

In [ ]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data # required

In [ ]:
# Sanity check of the reshaping used in TensorFromMelspectraFile: a long
#   (channels x time) array is cut into consecutive time blocks, giving a
#   (n_blocks, channels, block_len) tensor. Previously this sat under
#   `if False:` and never ran; it is now an executed assertion.
# Want to convert long set of mels into batches of length mel_samples:
# 0 :
#   10    11    12    13    14
#   20    21    22    23    24
#   30    31    32    33    34
# 1 :
#   15    16    17    18    19
#   25    26    27    28    29
#   35    36    37    38    39
layout_demo = torch.from_numpy(np.array([[10,11,12,13,14,15,16,17,18,19],
                                         [20,21,22,23,24,25,26,27,28,29],
                                         [30,31,32,33,34,35,36,37,38,39]
                                        ]))
layout_blocks = layout_demo.transpose(0,1).contiguous().view(2,5,3).transpose(1,2)
assert tuple(layout_blocks.size()) == (2, 3, 5)
assert layout_blocks[0].tolist() == [[10,11,12,13,14], [20,21,22,23,24], [30,31,32,33,34]]
assert layout_blocks[1].tolist() == [[15,16,17,18,19], [25,26,27,28,29], [35,36,37,38,39]]
layout_blocks

def TensorFromMelspectraFile(melspectra_filepath, block_len=mel_samples):
    """Load 'mel_log' from a cached hickle file as a (n_blocks, channels, block_len) tensor.

    The log-mel array (channels x frames) is cut into consecutive,
    non-overlapping blocks of `block_len` frames; trailing frames that do not
    fill a whole block are dropped. Pass block_len=None to get the whole file
    as a single block of shape (1, channels, frames).

    Raises ValueError if the effective block length is not positive
    (e.g. block_len=None on a file with zero frames).
    """
    data = hkl.load(melspectra_filepath)
    mel_log = data['mel_log']
    n_channels, n_frames = mel_log.shape

    if block_len is None: # Allow for 'whole of file' tensor(1,mels,everything)
        block_len = n_frames
    if block_len <= 0:
        raise ValueError("block_len must be positive (%s has %d frames)"
                         % (melspectra_filepath, n_frames))
    n_blocks = n_frames//block_len
    print("Read %5d log(mel[%2d]) = %4d blocks from %s" %
          (n_frames, n_channels, n_blocks, melspectra_filepath,))

    # frames x channels: each run of block_len rows is then one block
    mel_log_trunc_t = mel_log[:, :n_blocks*block_len ].T
    # Use the file's own channel count (not the global n_mels) so a file made
    #   with a different n_mels fails loudly / reshapes correctly instead of
    #   being scrambled by .view().
    return ( torch.from_numpy(mel_log_trunc_t).contiguous()
             .view(n_blocks, block_len, n_channels).transpose(1,2))

In [ ]:
# One (input, target) dataset per file; input and target are the same
#   tensor of blocks (autoencoder setup).
mel_datasets = []
for mel_filename in mel_filenames:
    blocks = TensorFromMelspectraFile(mel_filename)
    mel_datasets.append(torch.utils.data.TensorDataset(blocks, blocks))

In [ ]:
# A single indexable dataset spanning all files, for the DataLoader
mel_dataset = torch.utils.data.ConcatDataset(datasets=mel_datasets)

In [ ]:
# Choose tensor constructors once, so later cells can stay device-agnostic
use_cuda = torch.cuda.is_available()
if use_cuda:
    ftype, ltype = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    ftype, ltype = torch.FloatTensor, torch.LongTensor
use_cuda

In [ ]:
class WaveNettyCell(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, cond_channels=0, 
                 kernel_size=2, stride=1, dilation=1):
        super(WaveNettyCell, self).__init__()
        
        self.gate   = torch.nn.Conv1d(in_channels, hidden_channels, 
                                    kernel_size=kernel_size, 
                                    stride=stride, dilation=dilation, 
                                    padding=0, groups=1, bias=True)
        self.signal = torch.nn.Conv1d(in_channels, hidden_channels, 
                                    kernel_size=kernel_size, 
                                    stride=stride, dilation=dilation, 
                                    padding=0, groups=1, bias=True)
        
        self.cond = cond_channels>0
        if self.cond:
            self.gate_cond   = torch.nn.Conv1d(cond_channels, hidden_channels, 
                                               kernel_size=1, bias=False)
            self.signal_cond = torch.nn.Conv1d(cond_channels, hidden_channels, 
                                               kernel_size=1, bias=False)

        self.pad_end = (kernel_size-1)*(dilation+stride*0)

        self.recombine = torch.nn.Conv1d(hidden_channels, in_channels, 
                                    kernel_size=1, stride=1, dilation=1, 
                                    padding=0, # Only accepts symmetrical values - pad separately
                                    groups=1, bias=True)
            
    def forward(self, input, condition=None):
        gate = self.gate(input)
        signal = self.signal(input)
        if self.cond:
            gate   = gate   + self.gate_cond(condition)
            signal = signal + self.signal_cond(condition)

        gate = F.sigmoid(gate)
            
        mult = gate * F.tanh(signal)
        
        # The padding here is at the 'end' rather than start, 
        #   since we want the signals to be 'forward looking'
        #   this is not the same as WaveNets for generating new data
        #   which should only be 'backward looking'
        
        # Yes : There's no side/skip here : It's just a fancy feed-forward
        #return input + F.pad( self.recombine(mult), (0, self.pad_end) )
        return input*0.8 + F.pad( self.recombine(mult), (0, self.pad_end) )
        #return F.pad( self.recombine(mult), (0, self.pad_end) )

In [ ]:
class VQ_encoder(torch.nn.Module):
    """Encoder: five WaveNettyCells with dilations 1, 2, 4, 8, 16.

    Every cell maps in_channels -> in_channels, so the output has the same
    shape as the input. `pad_end` is the sum of the cells' pad_end values,
    i.e. the number of trailing frames affected by their end-padding.
    """
    def __init__(self, in_channels, hidden_channels=128):
        super(VQ_encoder, self).__init__()

        # See https://fomoro.com/tools/receptive-fields/
        #
        # Alternative tried: 4 cells with stride=2 (downsampling) instead of dilation
        #   #3,2,1,VALID;3,2,1,VALID;3,2,1,VALID;3,2,1,VALID;3,2,1,VALID
        #
        #   #3,1,1,VALID;3,1,2,VALID;3,1,4,VALID;3,1,8,VALID;3,1,16,VALID
        #   receptive field = 63 timesteps for kernel_size=3
        #   (1 + 31 = 32 timesteps with the WaveNettyCell default kernel_size=2)
        dilations = [1, 2, 4, 8, 16]
        cells = [WaveNettyCell(in_channels, hidden_channels, dilation=d)
                 for d in dilations]
        self.conv = torch.nn.ModuleList(cells)

        self.pad_end = sum(cell.pad_end for cell in self.conv)

    def forward(self, input):
        out = input
        for cell in self.conv:
            out = cell(out)
        return out

In [ ]:
class VQ_quantiser(torch.nn.Module):
    """Vector-quantisation bottleneck (VQ-VAE style).

    Each `latent_dim`-long vector along the channel axis of the input
    (B x C x T with C == latent_dim) is replaced by its nearest row, under
    squared-L2 distance, of a codebook of `n_symbols` embeddings. The codebook
    is filled lazily from the first input seen (init_weights_informed).
    A running count of chosen symbols is kept in `symbol_hist`.

    The nearest-row lookup gives no gradient w.r.t. the input, so forward()
    hooks the gradient arriving at the quantised output and stores it;
    backward_input_itself() then back-propagates that gradient into the
    quantiser's input (a straight-through estimator).
    """
    def __init__(self, n_symbols, latent_dim):
        super(VQ_quantiser, self).__init__()

        # See : https://github.com/nakosung/VQ-VAE/blob/master/model.py#L16
        self.n_symbols, self.latent_dim = n_symbols, latent_dim
        self.embedding = torch.nn.Embedding(n_symbols, latent_dim)  # k_dim=n_symbols, z_dim=latent_dim
        # , max_norm=1.0

        #self.init_weights_random()
        # The codebook is filled from real input vectors on the first forward()
        self.init_weights_done = False
        self.symbol_hist_init()

    def init_weights_random(self):
        """Uniform codebook init in [-1/n_symbols, +1/n_symbols] (currently unused)."""
        initrange = 1. / self.n_symbols
        self.embedding.weight.data.uniform_(-initrange, initrange)

    def init_weights_informed(self, Z):
        """Set the codebook to `n_symbols` randomly chosen rows of Z.

        Z holds one latent vector per row (the flattened quantiser input).
        If Z has fewer rows than `n_symbols`, the random permutation is tiled
        so rows are reused and the codebook keeps its (n_symbols, latent_dim)
        shape - otherwise it would silently shrink to Z.size(0) entries and
        no longer match the saved/loaded state_dict.

        Raises ValueError if Z has no rows.
        """
        print("init_weights_informed : Z.size() ", Z.size())
        n_rows = Z.size(0)
        if n_rows == 0:
            raise ValueError("init_weights_informed needs at least one latent vector")
        order = torch.randperm(n_rows)
        if n_rows < self.n_symbols:
            n_repeats = (self.n_symbols + n_rows - 1) // n_rows
            order = order.repeat(n_repeats)
        order = order[0:self.n_symbols].type(ltype)
        Z_ordered = Z[order]
        self.embedding.weight.data = Z_ordered.data
        self.init_weights_done = True

    def symbol_hist_init(self):
        """Reset the per-symbol usage counts to zero."""
        self.symbol_hist = torch.zeros( (self.n_symbols,) ).type(ltype)

    def symbol_hist_update(self, nearest_idx):
        """Add the occurrences of each symbol index in `nearest_idx` to symbol_hist."""
        nearest_idx_values = nearest_idx.data.cpu().type(torch.FloatTensor)
        # n_symbols bins over [0, n_symbols-1]: each integer index lands in its own bin
        hist_now = torch.histc(nearest_idx_values, bins=self.n_symbols,
                                        min=0, max=self.n_symbols-1).type(ltype)
        self.symbol_hist += hist_now

    def forward(self, input):
        """Quantise `input` (B x C x T).

        Returns (quantised B x C x T, symbol indices B x T,
                 codebook loss (gradient reaches the embedding only),
                 commitment loss (gradient reaches the input only)).
        """
        #return input, [], 0.,  0.  # Doesn't do quantisation yet...
        sz = input.size()

        # BCT -> BTC -> B(TC in one long strip)
        Z = input.permute(0,2,1).contiguous().view(-1, self.latent_dim)
        if not self.init_weights_done:
            self.init_weights_informed(Z)

        W = self.embedding.weight

        def L2_dist(a,b):
            return ((a - b) ** 2)

        # Find nearest embedding for every vector in the long strip:
        #   (rows, 1, C) - (1, n_symbols, C) -> squared distances summed over C,
        #   then the index of the minimum per row
        nearest_idx = L2_dist( Z[:,None], W[None,:] ).sum(2).min(1)[1]
        W_nearest_latent = W[nearest_idx]  # Convert indices into latent vectors

        self.symbol_hist_update(nearest_idx)

        # B(TC) -> BCT i.e. re-roll back into the input layout
        out = W_nearest_latent.view(sz[0],sz[2],sz[1]).permute(0,2,1)

        def hook(grad):
            # This is being called for Embedding updates.
            # Store the grad to pass along as an input update too
            # This isn't 'perfect' according to the paper, but should work well enough
            self.saved_input_to_vq = input
            self.saved_grad_for_input_to_vq = grad
            return grad

        out.register_hook(hook)

        # Stop gradients (_sg) for additional loss terms
        Z_sg = Z.detach()
        W_nearest_latent_sg = W_nearest_latent.detach()

        return (out,
                nearest_idx.view(sz[0],sz[2]),
                # return additional loss values too to optimise embedding and input respectively
                L2_dist(Z_sg, W_nearest_latent).sum(1).mean(),
                L2_dist(Z, W_nearest_latent_sg).sum(1).mean(),
               )

    # back propagation for inputs to VQ (rather than just to the embeddings)
    def backward_input_itself(self):
        """Push the gradient captured by forward()'s hook into the quantiser input.

        Only valid after a backward pass through `out` has fired the hook;
        the caller must have kept the graph (retain_graph=True).
        """
        self.saved_input_to_vq.backward(self.saved_grad_for_input_to_vq)
        return

In [ ]:
# Scratch check of the nearest-embedding lookup used in VQ_quantiser.forward:
#   a toy 3-channel x 10-step signal, cut into 2 blocks of 5 steps, is matched
#   against a 7-entry codebook `e` (one 3-channel vector per row).
# NB: this cell rebinds the globals t, t_batch, e, Z, W, nearest_idx and W_nearest_latent.
t = torch.from_numpy(np.array([[10,11,12,13,14,15,16,17,18,19], 
                               [20,21,22,23,24,25,26,27,28,29], 
                               [30,31,32,33,34,35,36,37,38,39]
                              ], dtype=np.float32))
#t.view(2,3,5)
t_batch = t.transpose(0,1).contiguous().view(2,5,3).transpose(1,2)  # (2 blocks, 3 channels, 5 steps)
t_batch

e = torch.from_numpy(np.array([
    [10,20,30], 
    [12,22,32], 
    [15,25,35], 
    [15,23,35], 
    [17,24,32], 
    [14,27,35], 
    [27,24,12], 
  ], dtype=np.float32))
e
def L2_dist_local(a,b):
    # Element-wise squared difference; the caller sums over channels
    return ((a - b) ** 2)
Z = t_batch.permute(0,2,1).contiguous().view(-1, 3) # Laid out as one big batch: (10 vectors, 3 channels)
Z
W = e
# Sample nearest embedding 
#   Form matrix of all L2 (sum) distances, then take the index of the minimum.
#   Z[:,None] - W[None,:] broadcasts to (10 vectors, 7 codes, 3 channels); .sum(2) -> (10, 7)
#nearest_idx = L2_dist_local( Z[:,None], W[None,:] )
#nearest_idx = L2_dist_local( Z[:,None], W[None,:] ).sum(2)
# min(1) returns (min values, argmin indices) over the codebook axis
nearest_idx = L2_dist_local( Z[:,None], W[None,:] ).sum(2).min(1)
#nearest_idx
#nearest_idx = L2_dist_local( Z[:,None], W[None,:] ).sum(2).min(1)[1]
W_nearest_latent = W[nearest_idx[1]]  # Convert indices into latent vectors
W_nearest_latent

#t = torch.from_numpy(np.array([[10,11,12], 
#                           [20,21,22], 
#                          ]))
#t.size(0)
#t[None, :] # Adds an extra column at beginning
#t[:, None] # Inserts an extra column in middle

In [ ]:
class VQ_decoder(torch.nn.Module):
    """Decoder: five WaveNettyCells with dilations 1, 2, 4, 8, 16.

    `latent_channels` and forward()'s `latent` are accepted but unused while
    conditioning is disabled (the cells are built without cond_channels).
    `pad_end` is the sum of the cells' pad_end values.
    """
    def __init__(self, in_channels, latent_channels=0, hidden_channels=128):
        super(VQ_decoder, self).__init__()

        cells = []
        for d in (1, 2, 4, 8, 16):
            # Conditioning on the latent (cond_channels=latent_channels) is disabled
            cells.append(WaveNettyCell(in_channels, hidden_channels, dilation=d))
        self.conv = torch.nn.ModuleList(cells)

        self.pad_end = sum(cell.pad_end for cell in self.conv)

    def forward(self, input, latent=None):
        x = input
        for cell in self.conv:
            x = cell(x)  # would be cell(x, latent) if conditioning were enabled
        return x

In [ ]:


In [ ]:
class VQ_VAE_Model(torch.nn.Module):
    """VQ-VAE over log-mel blocks: encoder -> vector quantiser -> decoder.

    Inputs/targets are (batch, n_mels, frames). The model owns its Adam
    optimiser, so training steps, state snapshots and checkpoints all go
    through the methods here.
    """
    def __init__(self):
        super(VQ_VAE_Model, self).__init__()

        self.channels, self.n_symbols = n_mels, 16
        #self.channels, self.n_symbols = n_mels, 64

        self.encoder = VQ_encoder(self.channels)
        self.quant   = VQ_quantiser(self.n_symbols, self.channels)
        self.decoder = VQ_decoder(self.channels)

        print(f"Number of parameter variables : {len(list(self.parameters()))}")

        # Trailing output frames affected by end-padding in encoder + decoder;
        #   they are excluded from the reconstruction loss in train_()
        self.pad_end = self.encoder.pad_end + self.decoder.pad_end

        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        #self.optimizer = torch.optim.RMSprop(self.parameters())  # Converges hardly at all
        #self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)

        # Hmm : Adaptive learning rate ideas:
        #   https://github.com/fastai/fastai/blob/master/fastai/learner.py#L216
        # And : http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.CosineAnnealingLR

    def update_lr(self, lr):
        """Set the learning rate of every optimiser parameter group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def forward(self, input):
        """Return (reconstruction, symbols B x T, codebook loss, commitment loss)."""
        #vq_input = input
        x = self.encoder(input)
        x, symbols, loss_e1, loss_e2 = self.quant(x)
        x = self.decoder(x)
        #x = self.decoder(x)
        return x, symbols, loss_e1, loss_e2

    def train_(self, input, target, take_step=True):
        """Compute the losses on one batch and optionally take an optimiser step.

        Returns (total loss, reconstruction MSE, codebook loss, commitment loss).
        """
        valid_len = input.size(2) - self.pad_end

        self.train()  # Set mode

        output, symbols, loss_e1, loss_e2 = self(input)

        #if float(loss_e2)>10.:
        #print("Symbols : %s" % (' '.join([ ('%2d' % int(v)) for v in symbols[0,0:25]]),))
        #print("Symbol.hist : %s" % (' '.join([ ('%2d' % int(v)) for v in self.quant.symbol_hist[0:20]]),))

        loss_rec = F.mse_loss(output[:,:,:valid_len], target[:,:,:valid_len])
        #loss_rec = F.smooth_l1_loss(output[:,:,:valid_len], target[:,:,:valid_len])

        loss = 10.*loss_rec + loss_e1 + 0.25*loss_e2  # MAGIC NUMBERS

        if take_step:
            self.gradient_step(loss)

        return ( loss, loss_rec, loss_e1, loss_e2 )

    def gradient_step(self, loss):
        """Back-propagate `loss` (plus the quantiser's straight-through gradient) and step."""
        self.optimizer.zero_grad()
        # retain_graph: backward_input_itself() re-traverses the encoder graph
        loss.backward(retain_graph=True)
        self.quant.backward_input_itself()
        # NB: newer PyTorch names this clip_grad_norm_
        torch.nn.utils.clip_grad_norm(self.parameters(), 0.5)  # MAGIC NUMBER
        self.optimizer.step()

    def predict_(self, input):
        """Run in eval mode; return (reconstruction, symbols)."""
        self.eval()
        output, symbols, loss_e1, loss_e2 = self(input)
        return output, symbols

    def get_state(self):          # Returns a tuple of the states
        """Snapshot (model state_dict, optimiser state_dict) for set_state().

        Deep-copied: state_dict() returns references to the live parameter and
        optimiser-state tensors, which training then updates in place, so an
        un-copied snapshot would follow the model and restore nothing.
        """
        import copy
        return copy.deepcopy(self.state_dict()), copy.deepcopy(self.optimizer.state_dict())
    def set_state(self, states):  # ... resumable here...
        """Restore a snapshot produced by get_state()."""
        self.load_state_dict(states[0])
        self.optimizer.load_state_dict(states[1])

    @staticmethod
    def _optim_filename(filename):
        """Sibling path for the optimiser state: 'dir/x.ext' -> 'dir/x.optim.ext'.

        Unlike filename.replace('.pkl', ...), this also works for '.pth' (or
        extension-less) names, where replace() returned the model path itself
        and the optimiser state overwrote the model file.
        """
        root, ext = os.path.splitext(filename)
        return root + '.optim' + ext

    def save(self, filename='model/tmp.pkl', with_optimiser=False):
        """Write the model state_dict to `filename`, creating its directory if needed.

        With with_optimiser=True the optimiser state is written alongside
        (e.g. 'model/tmp.optim.pkl').
        """
        #torch.save(self.state_dict(), 'model/epoch_{}_{:02d}.pth'.format(self.name, epoch))
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.state_dict(), filename)
        if with_optimiser:
            torch.save(self.optimizer.state_dict(), self._optim_filename(filename))

    def load(self, filename='model/tmp.pkl', with_optimiser=False):
        """Load weights (and optionally optimiser state) written by save()."""
        self.load_state_dict(torch.load(filename))
        # The codebook comes from the checkpoint: don't re-initialise it from data
        self.quant.init_weights_done = True
        if with_optimiser:
            self.optimizer.load_state_dict(torch.load(self._optim_filename(filename)))

In [ ]:
# Build the model and move it to the GPU when one is available
model = VQ_VAE_Model()
model = model.cuda() if use_cuda else model

In [ ]:
# http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

def train_epoch(epoch, learning_rate, take_step=True, shuffle=True):
    """Run one pass of the global `model` over `mel_dataset`.

    Args:
        epoch:         only used in the progress printout
        learning_rate: written into every optimiser parameter group first
        take_step:     if False, losses are computed but no optimiser step is
                       taken (used to evaluate without training)
        shuffle:       shuffle the DataLoader order

    Returns:
        np.ndarray of the 4 losses returned by model.train_, each summed over
        all batches: (total, reconstruction, codebook, commitment).
    """
    train_batches = torch.utils.data.DataLoader(mel_dataset, batch_size=batch_size,
                                                shuffle=shuffle, num_workers=1)
    model.quant.symbol_hist_init()
    model.update_lr(learning_rate)
    losses = np.zeros( shape=(4,) )
    # Count batches explicitly: the old `batch_idx` was one short in the
    #   printout and undefined (NameError) when the loader yielded nothing.
    n_batches = 0
    for input, target in train_batches:
        x = Variable( input.type(ftype) )
        y = Variable( target.type(ftype) )
        losses_arr = model.train_(x, y, take_step=take_step)
        losses += np.array( [float(v) for v in losses_arr ])
        n_batches += 1

        #print(f"Epoch {epoch:2}, Batch {n_batches:2}, %.6f" % (float(mse*1000*1000),))

    # Show every symbol's count, not a hard-coded first 16
    n_symbols = model.quant.n_symbols
    print("  Symbol.hist : %s" % (' '.join([ ('%5d' % int(v)) for v in model.quant.symbol_hist[0:n_symbols]]),))
    print(f"Epoch {epoch:4}, %2d batches, %s" % (
        n_batches, ', '.join([ ("%8.2f" % v) for v in losses.tolist() ]),))

    return losses

In [ ]:
def find_loss_rate(epoch, lr_current, lr_factor=1.0):  #, lr_step=1.2
    """Pick a learning rate near `lr_current` by trial epochs.

    Each candidate lr = (1 + s*lr_factor) * lr_current, for s in
    [0, -.1, +.1, -.15, +.2, -.2], is tried from the SAME starting state:
    one training epoch (take_step=True; its summed loss is `loss_base`)
    followed by one evaluation epoch without steps (`loss_this`).
    The candidate with the lowest `loss_this` is returned. The model and
    optimiser state on entry is restored before returning.

    Args:
        epoch:      label passed through to train_epoch's printout
        lr_current: learning rate to search around
        lr_factor:  scales the relative offsets s

    Returns:
        The chosen learning rate.
    """
    import copy
    # Deep copy: state_dict()s hold references to live tensors which the trial
    #   epochs update in place - without a copy set_state() would restore nothing.
    saved_state = copy.deepcopy(model.get_state())    # Save before any non-standard learning rate applied

    loss_best, lr_loss_best = None, lr_current
    loss_performance_best, lr_performance_best = None, lr_current

    for lr_s in [0.0, -.1, +.1, -.15, +.2, -.2, ]:
        lr = (1.+(lr_s*lr_factor))*lr_current

        print()
        print("  Trying learning rate : %.8f" % (lr,))

        # Repeatability will increase when fix from
        #   https://github.com/SeanNaren/deepspeech.pytorch/issues/210
        #   is installed
        model.set_state(saved_state)

        # Summed total loss while training one epoch at this lr ...
        loss_epoch = train_epoch(epoch, lr, take_step=True, shuffle=False)
        loss_base = loss_epoch[0]

        # ... and re-measured afterwards without taking any steps
        loss_epoch = train_epoch(epoch, lr, take_step=False, shuffle=False)
        loss_this = loss_epoch[0]

        if (loss_best is None) or loss_best>loss_this:
            loss_best=loss_this
            lr_loss_best = lr

        # loss_performance is the improvement achieved over the epoch under lr
        loss_performance = (loss_base - loss_this)
        print("  loss_performance : %.8f" % (loss_performance,))

        # (previously a typo assigned an unused `lr_best` here)
        if (loss_performance_best is None) or loss_performance_best<loss_performance:
            loss_performance_best=loss_performance
            lr_performance_best = lr

    # Informational: the returned rate is the one with the lowest post-epoch loss
    print("  best loss_performance at learning rate : %.8f" % (lr_performance_best,))
    lr = lr_loss_best
    print("learning rate set to : %.8f" % (lr,))
    model.set_state(saved_state)    # Revert to the state on entry
    return lr

In [ ]:
# Main training loop: epochs*100 passes. Every 50 epochs, starting at epoch 20
#   (after a warm-up), the learning rate is re-tuned around its current value.
lr = 0.01
for epoch in range(epochs*100):
    print()
    loss_epoch = train_epoch(epoch, lr)
    if epoch % 50 == 20:  # Allow for warm-up
        # Pass the real epoch (was a hard-coded 0) so trial printouts are labelled correctly
        lr = find_loss_rate(epoch, lr)

In [ ]:
# f'{234.3453453453434534:6.2}'  Weird choice for format specifiers : overall_width.significant_digits (no type code, so this gives '2.3e+02')

In [ ]:
# Scratch: quick checks of tensor/array helpers. Only the final expression
#   (the list) is displayed by the notebook; t.size(0) is evaluated but not shown.
# NB: this rebinds the global `t`.
t = torch.from_numpy(np.array([[10,11,12,13,14,15,16,17,18,19], 
                               [20,21,22,23,24,25,26,27,28,29], 
                               [30,31,32,33,34,35,36,37,38,39]
                              ]))
t.size(0)
np.array([3,4,5]).tolist()

In [ ]:
# Save the model
# Earlier runs, kept for reference:
#model.save('model/16symbols-no-enc-dec_epoch_{:04d}.pth', epoch)
#model.save('model/16symbols_epoch_{:04d}_553.pth'.format(epoch))
#model.save('model/64symbols_epoch_{:04d}_344.pth'.format(epoch))
#model.save('model/16symbols_k2_epoch_{:04d}_353.pth'.format(epoch))
checkpoint_template = 'model/16symbols_k2_epoch_{:04d}_553.pth'
model.save(checkpoint_template.format(epoch))

In [ ]:
# Restore a trained checkpoint. model.load() also sets quant.init_weights_done,
#   so the first forward() keeps the loaded codebook instead of re-initialising it.
#model.load('model/16symbols_epoch_0999_553.pth')
#model.load('model/64symbols_epoch_0999_344.pth')
model.load('model/16symbols_k2_epoch_0999_553.pth')
#model.load('model/64symbols_k2_epoch_0999_353.pth')

In [ ]:
# Need to get in a full batch to view
# Encode one whole test file (a single block spanning all its frames) and collect
#   its symbol sequence. The histogram is reset first, so afterwards it counts
#   only this file's symbols.
mel_filename_test = mel_filenames[1]
model.quant.symbol_hist_init()
test_input = TensorFromMelspectraFile(mel_filename_test, block_len=None) # Whole file
test_input
test_output, test_symbols = model.predict_( Variable( test_input.type(ftype) ) )
#test_symbols
model.quant.symbol_hist

In [ ]:


In [ ]:
#symb_to_char = "abcdef hikjlmnop" # k=3
symb_to_char = " bcdefghikjlmnop" # k=2
#symb_to_char = "abcdefghikjlmnopqrstuvwxyzABCDEFGHIJKLMNOP RSTUVWXYZ0123456789-+*@" # k=3
#symb_to_char = "abcdefghikjlmnopqrstuvwxyzABCDEFGHI KLMNOPQRSTUVWXYZ0123456789-+*@" # k=2
# One character per codebook index (16 characters for the 16-symbol model).

# Drop the last model.pad_end symbols (trailing frames affected by the cells'
#   end-padding). Slice by an explicit length: `[:-model.pad_end]` would
#   return an EMPTY sequence if pad_end were 0.
test_symbols_np = test_symbols.data.cpu().numpy()[0]
n_valid_symbols = max(len(test_symbols_np) - model.pad_end, 0)
chars = ''.join( [symb_to_char[v] for v in test_symbols_np[:n_valid_symbols]] )
len(chars),chars

In [ ]:
#mel_filename_test
# Write the symbol string next to the mel file; the '.16_k2' tag presumably
#   means 16 symbols, kernel_size 2 - TODO confirm naming convention.
sym_filepath = mel_filename_test.replace('.hkl', '.16_k2.sym')
with open(sym_filepath, 'wt') as f:
    f.write(chars)

In [ ]:
# Convert symbols to mels - and have a listen (tricky... since that's a whole project in itself)
#test_output...

In [ ]: