In [ ]:
import numpy as np
np.set_printoptions(precision=3)

import matplotlib.pyplot as plt

import re

In [ ]:
fft_step   = 12.5/1000. # 12.5ms
fft_window = 50.0/1000.  # 50ms

audio_filenames = [ './librivox/guidetomen_%02d_rowland_64kb.mp3' % (i,) for i in [1,2,3]]
audio_filenames

mel_filenames = [ f.replace('.mp3', '.melspectra.hkl') for f in audio_filenames ]

In [ ]:
audio_filename_test_idx = 1 

#FILE: guidetomen_02_rowland_64kb.mp3
#OFFSET_START: 7.0
#OFFSET_END: 613.0
offset_start, offset_end = 7.0, 613.0

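In [ ]:
# The '#OFFSET_*' markers above could equally be parsed with the `re`
#   module imported earlier.  A minimal sketch, assuming the markers
#   appear as lines formatted like those in the previous cell:
def parse_offsets(text):
    offsets = dict()
    for key in ['OFFSET_START', 'OFFSET_END']:
        m = re.search(r'#%s:\s*([0-9.]+)' % (key,), text)
        if m:
            offsets[key] = float(m.group(1))
    return offsets

parse_offsets("#OFFSET_START: 7.0\n#OFFSET_END: 613.0")
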
In [ ]:
mel_filename_test = mel_filenames[audio_filename_test_idx]

#with open(mel_filename_test.replace('.hkl', '.16_k2.sym'), 'rt') as f:
with open(mel_filename_test.replace('.hkl', '.64_k2.sym'), 'rt') as f:
    mel_sym_str = f.read()
    
mel_sym_chars = set(mel_sym_str)
mel_sym_dict  = { c:i for i,c in enumerate(sorted(list(mel_sym_chars))) }
mel_sym = np.array( [mel_sym_dict[c] for c in mel_sym_str] )

mel_sym_silence = mel_sym_dict[' ']

mel_sym_str[100:115], mel_sym[100:115], mel_sym.shape[0]

In [ ]:
def print_one_sec_per_line(s, t_min=0., t_max=None):
    #each_line = int(1/fft_step)
    if t_max is None: t_max = len(s)*fft_step
    for t in np.arange(t_min, t_max, 1.):
        i_min = int(t/fft_step)
        i_max = int((t+1.)/fft_step)
        if i_max>t_max/fft_step: 
            i_max = int(t_max/fft_step)
        print("%6.2f %s" % (t, s[i_min:i_max]) )
        
print_one_sec_per_line(mel_sym_str, 0.0, 10.0)

In [ ]:
# Run-length encode the sym data : for each run of identical chars,
#   store the char ('cc'), the repeat count ('cn') and the start index ('ct')
mel_sym_ct, mel_sym_cc, mel_sym_cn = [], [], []
prev_c, prev_n = '', 0
for t, c in enumerate(mel_sym_str):
    if c==prev_c: 
        prev_n+=1 # Add one to count
    else:
        mel_sym_cn.append(prev_n)  # Store count of previous char
        mel_sym_ct.append(t)       # Store new char's start index
        mel_sym_cc.append(c)       # Store new char's value
        prev_c, prev_n = c, 1
mel_sym_cn.append(prev_n)  # Store last count value

mel_sym_ct = np.array( mel_sym_ct )
mel_sym_ci = np.array( [mel_sym_dict[c] for c in mel_sym_cc] )
mel_sym_cn = np.array( mel_sym_cn[1:])   # Kill the first value, and convert to numpy

(mel_sym_str[66:95],   # Original String 
 mel_sym_cc[0:10],     # Distinct characters
 mel_sym_ci[0:10],     # Distinct symbol indices
 mel_sym_cn[0:10],     #   x number of repetitions
 mel_sym_ct[0:10],     # mel time
 mel_sym_ci.shape[0],) # total number of mels

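In [ ]:
# Quick sanity check : the run-length encoding should round-trip
#   exactly back to the original symbol string
''.join( c*int(n) for c, n in zip(mel_sym_cc, mel_sym_cn) ) == mel_sym_str
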
In [ ]:


In [ ]:
# Let's find the ranges of the actual sounds in the 
# speech audio - ignoring short silences

#silence_is_short_len = 8  # 100ms
silence_is_short_len = 40  # 500ms

audio_spans = []
def add_span(span_start_index, s_i, s_n):
    if len(s_i)>0:
        span=[]
        for sym, count in zip(s_i, s_n):
            span.extend( [sym]*count )
        audio_spans.append(dict(
            t_start=span_start_index,
            t_end  =span_start_index+len(span),
            syms=s_i,
            count=s_n,
            span=span,
        ))
        #print(span_start_index, 
        #      mel_sym_str[span_start_index:span_start_index+len(span)])

span_start, span_i, span_n = -1, [], []  # Indices and counts
for idx, c in enumerate(mel_sym_cc):
    if span_start<0: 
        span_start = mel_sym_ct[idx]
    ci, cn = mel_sym_ci[idx], mel_sym_cn[idx]
    
    if ci==mel_sym_silence and cn>silence_is_short_len:
        #print(cn)
        add_span(span_start, span_i, span_n)
        span_start, span_i, span_n = -1, [], []
        continue 
        
    span_i.append(ci)
    span_n.append(cn)
add_span(span_start, span_i, span_n)

len(audio_spans), #audio_spans[92]

In [ ]:
# Let's take off the first few and last few (outside offset_start, offset_end)
audio_spans = [a for a in audio_spans
                  if  a['t_start']*fft_step>offset_start 
                  and a['t_end']  *fft_step<offset_end
              ]
len(audio_spans), #audio_spans[92]

In [ ]:
import os
import random, string

# pip install soundfile
import soundfile
import librosa

from IPython.display import Audio as audio_playback_widget

os.makedirs('./data/tmp', exist_ok=True)

In [ ]:
audio_filename_test = audio_filenames[audio_filename_test_idx]
audio_samples, _sample_rate = librosa.core.load(audio_filename_test, sr=None)
audio_samples = audio_samples/np.max(np.abs(audio_samples))  # Normalise so the peak magnitude is 1.0

def play_audio_span(s, autoplay=False):
    hsh = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
    f = './data/tmp/%s.wav' % (hsh,)
    def ts(t_mel, end=False):  # t_mel to samples
        return int( (t_mel*fft_step + (fft_window if end else 0.))*_sample_rate )
    audio_span = audio_samples[ts(s['t_start']):ts(s['t_end'], end=True)]
    soundfile.write(f, audio_span, samplerate=_sample_rate)
    
    plt.figure(figsize=(12,2))
    plt.plot(audio_span)
    #plt.plot(np.arange(s['t_start'],s['t_end'],), audio_span)
    #plt.xticks( np.arange(s['t_start'], s['t_end'], 20.), rotation=90 )
    plt.grid(True)

    plt.show()
    
    return audio_playback_widget(f, autoplay=autoplay)

In [ ]:
#play_audio_span(audio_spans[0])    # audio_spans[0] is first word in text
#play_audio_span(audio_spans[113])  # audio_spans[-1] is last phrase in text

In [ ]:
# Bachelors  (near beginning)
print_one_sec_per_line(mel_sym_str, 0./80, 800./80)
#play_audio_span(dict(t_start=0, t_end=800), autoplay=True)
print()
print_one_sec_per_line(mel_sym_str, 640./80, 700./80)
play_audio_span(dict(t_start=640, t_end=690), autoplay=True)

In [ ]:


In [ ]:


In [ ]:


In [ ]:
# Read in text as words

# Create initial array of word starts
#   with initial guess of maximum error bars
# Create map of word -> word_index

In [ ]:


In [ ]:
with open(mel_filename_test.replace('.melspectra.hkl', '.txt'), 'rt') as f:
    mel_txt = f.read()

txt_arr = mel_txt.replace('\n', ' ').split(' ')
#txt_arr.insert(0, '#EOS') # Extra one at start
#txt_arr.insert(0, '') # Helps start process 
len(txt_arr), ','.join( txt_arr[0:10] )

In [ ]:
# Quick-and-dirty sanity check : Flip all words over and expect failure
#txt_arr = txt_arr[::-1]

In [ ]:
sentence_spans = []
def add_sentence_span(span_start_index, span):
    if len(span)>0:
        sentence_spans.append(dict(
            t_start= span_start_index, # within txt_arr
            t_end  = span_start_index+len(span),
            span   = span,
        ))

span_start, span_words = -1, []  # Indices and text
for idx, w in enumerate(txt_arr):
    if span_start<0: 
        span_start = idx
    if w=='#EOS':
        add_sentence_span(span_start, span_words)
        span_start, span_words = -1, []
        continue 
    span_words.append(w)
add_sentence_span(span_start, span_words)

len(sentence_spans), sentence_spans[0]

In [ ]:


In [ ]:
# Estimate each word's length in characters ('#EOS' is padded out, since
#   sentence breaks take extra time) - the timings in seconds are derived
#   from these estimates below
txt_length_est = np.array( [ len(s.replace('#EOS', '#EOS-is-longish')) for s in txt_arr ] )
#txt_err = np.zeros_like( txt_starts )

In [ ]:
# Total up the lengths of the words (in characters) 
#  - the initial timing guess is going to be proportional 
txt_length_est[0:10]

In [ ]:
txt_err_min = 2. # Seconds
txt_err_pct = 0.10 # i.e. plus or minus this amount of 'unknown length'

In [ ]:
# Need function to fix starts and errs for every word in txt_arr
#   given a dict of word_index to known starts+errs
known_starts = {
    #0 : (offset_start, txt_err_min),
    #(txt_length_est.shape[0]-1) : (offset_end, txt_err_min),
    0 : (audio_spans[0]['t_start']*fft_step, txt_err_min),
    (txt_length_est.shape[0]-1) : (audio_spans[-1]['t_end']*fft_step, txt_err_min),
}
known_starts

In [ ]:
def create_starts_and_errs(known_starts):
    txt_starts  = np.zeros( txt_length_est.shape )
    txt_err_fwd = np.zeros_like( txt_starts )
    txt_err_bwd = np.zeros_like( txt_starts )
    
    known_start_i = sorted( known_starts.keys() )
    for i_start, i_end in  zip( known_start_i[:-1], known_start_i[1:]):
        v_start, v_end = known_starts[i_start], known_starts[i_end]
        
        # Create a span including updated duration estimates
        actual_duration = (v_end[0]-v_start[0])
        length_est = txt_length_est[i_start:i_end] # Copy range
        length_adj = length_est/length_est.sum()*actual_duration
        
        # Put the span into the txt_starts array
        cs_fwd = np.cumsum(length_adj)
        cs_bwd = np.flip(np.flip(length_adj, 0).cumsum(), 0)
        
        txt_starts[i_start] = v_start[0]
        txt_starts[i_start+1:i_end+1] = v_start[0] + cs_fwd

        txt_err_fwd[i_start] = v_start[1]
        txt_err_fwd[i_start+1:i_end+1] = v_start[1] + cs_fwd*txt_err_pct

        txt_err_bwd[i_end] = v_end[1]
        txt_err_bwd[i_start:i_end] = v_end[1] + cs_bwd*txt_err_pct
    
    return txt_starts, np.minimum(txt_err_fwd, txt_err_bwd)

#np.cumsum( np.array( [66,3,72,12,42] ))

In [ ]:
txt_starts, txt_err = create_starts_and_errs(known_starts)

print( txt_length_est[0:10], txt_starts[0:10], txt_starts[-10:] )
txt_err[0:5], txt_err[ txt_starts.shape[0]//2-5:txt_starts.shape[0]//2 ], txt_err[-5:],

In [ ]:
plt.plot(txt_starts, 'b')
plt.plot(txt_err*8., 'r')
plt.show()

In [ ]:
## Old version
#txt_starts = txt_length_est.cumsum()
#txt_starts = offset_start + txt_starts*(offset_end-offset_start)/txt_starts[-1]
#txt_starts[0:10]

In [ ]:
## Old version
#txt_err = np.array( [ (i-txt_starts.shape[0]) for i, txt in enumerate(txt_arr) ] )
#txt_err = np.square(txt_err)
#txt_err = txt_err.max() - txt_err # Now a parabola peaking in the middle
#
#txt_err_scale = (txt_err_max - txt_err_min) / txt_err[ txt_starts.shape[0]//2 ]
#
#txt_err = txt_err*txt_err_scale + txt_err_min
#
#txt_err[0:5], txt_err[ txt_starts.shape[0]//2:txt_starts.shape[0]//2+5 ]

In [ ]:
## Global alignment

# Idea : Train up a word embedding based on range (within error bars)
# Function to map word to set of ranges
# Function to convert ranges into a % of whole
# Table of words with % coverage (to check whether that's a doable idea)
# Find word-range average vector (vs. not-in-word-range)

In [ ]:
word_to_idx={}
for i, w in enumerate(txt_arr):
    if w not in word_to_idx: word_to_idx[w]=[]
    word_to_idx[w].append( i )
word_to_idx['bachelors'], len(txt_arr), len(word_to_idx)  #bachelor mere

In [ ]:
def word_range_mask(w):
    mask = np.zeros_like(mel_sym)
    for i in word_to_idx.get(w, []):
        t_min = txt_starts[i]-txt_err[i]
        if t_min<0: t_min=0
        
        if i+1 < txt_starts.shape[0]:
            i_next = i+1
        else: # Rare end-case:
            i_next = i
        t_max = txt_starts[i_next]+txt_err[i]
            
        #print("(i_min, i_max) = ", i_min, i_max)
        mask[ int(t_min/fft_step):int(t_max/fft_step) ] = 1.
    return mask

word_mask=word_range_mask('bachelors')  # bachelor mere woman man the
np.sum(word_mask) / word_mask.shape[0]

In [ ]:
plt.plot(word_mask, 'r')
plt.show()

In [ ]:
# Let's look at word frequencies, and mask coverage
words_freq_ordered = sorted(word_to_idx.keys(), key=lambda k: -len(word_to_idx[k]))
len(words_freq_ordered), words_freq_ordered[0], words_freq_ordered[-1]

In [ ]:
for w in words_freq_ordered:
    n = len(word_to_idx[w])
    if n<2: continue # Not enough for any meaningful stats...
    word_mask=word_range_mask(w)
    coverage=np.sum(word_mask) / word_mask.shape[0]
    print("%6.2f%%, %4d, %s" % (coverage*100., n, w))

In [ ]:
# Sweet spot is defined by maximising # of examples, while minimising
#   probability of mis-identification (i.e. low coverage is good)
# Perhaps, though, need to avoid overlapping masks (if possible?)

# Candidates (?) : love, one, never, how, marriage, give, me, right, greatest, getting
# Except : bachelor's and bachelors, 
# Except : man(~woman), woman's, man's

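In [ ]:
# A rough way to score the 'sweet spot' trade-off above : prefer many
#   examples and low mask coverage.  The n/coverage ratio is just an
#   illustrative heuristic, not a tuned metric:
def probe_word_candidates(top_n=10):
    scored = []
    for w in words_freq_ordered:
        n = len(word_to_idx[w])
        if n<2: continue
        coverage = np.sum(word_range_mask(w)) / mel_sym.shape[0]
        if coverage<=0.: continue
        scored.append( (n/coverage, n, coverage, w) )
    return sorted(scored, reverse=True)[:top_n]

#probe_word_candidates()
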
In [ ]:
#mel_sym_dict

In [ ]:
# Now create histogram of symbol frequencies corresponding to mask
def histogram_freqs(mask, remove_silence=True):
    print(np.sum(mask) / mask.shape[0])  # Report the mask coverage
    n_sym = len(mel_sym_chars)
    
    # minlength keeps the bin arrays aligned with np.arange(0, n_sym)
    inside_bins  = np.bincount(mel_sym, weights=mask,   minlength=n_sym)
    outside_bins = np.bincount(mel_sym, weights=1-mask, minlength=n_sym)
    
    if remove_silence:
        inside_bins[mel_sym_silence]=0
        outside_bins[mel_sym_silence]=0
    
    plt.figure(figsize=(12,4))
    rects1 = plt.bar(np.arange(0, n_sym)-.2, width=0.4, color='r',
                     height=inside_bins/np.sum(inside_bins))
    rects2 = plt.bar(np.arange(0, n_sym)+.2, width=0.4, color='b',
                     height=outside_bins/np.sum(outside_bins))

    plt.xlabel('Symbol#')
    plt.ylabel('Freq')
    plt.xticks(np.arange(0, n_sym, 1.0))
    plt.grid(True)
    plt.show()
    
histogram_freqs(word_range_mask('bachelors'))

In [ ]:


In [ ]:


In [ ]:
## Local alignment (looking within word-error-bar ranges only)

# Idea : Mostly textual alignment
# Have a look at the symbols in appropriate ranges for several word examples
# See whether a simple optimisation can align multiple segments
# Would reduce error bars massively
# Possibly : 
#   http://mlpy.sourceforge.net/docs/3.4/lcs.html#standard-lcs  (GPL3, though)
#   https://github.com/Samnsparky/py_common_subseq (MIT)
#   https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Longest_common_substring#Python_3 
#   https://docs.python.org/3/library/difflib.html (not quite...)

In [ ]:
#  !pip install py_common_subseq
#import py_common_subseq as subseq  # But doesn't output alignment

# Homebrew version : does what we actually want (return array of index correspondences)
#  See : https://en.wikipedia.org/wiki/Longest_common_subsequence_problem
def lcs(a_arr, b_arr, just_length=True):  
    m = np.zeros( (len(a_arr)+1, len(b_arr)+1) )
    # offset all the i and j by 1, since we need blank first row+col
    for i, a in enumerate(a_arr):
        for j, b in enumerate(b_arr):
            if a == b:
                m[i+1, j+1] = m[i, j] + 1
            else:
                m[i+1, j+1] = max(m[i+1, j], m[i, j+1])
    #print(m)
    if just_length:
        return m[-1, -1]
                
    a_i, b_i = [], []
    i, j = len(a_arr), len(b_arr)
    while i>0 and j>0:
        if m[i, j] == m[i-1, j]:
            i -= 1
        elif m[i, j] == m[i, j-1]:
            j -= 1
        else: # a_arr[i-1] == b_arr[j-1]
            a_i.append(i-1)
            b_i.append(j-1)
            i -= 1
            j -= 1
    
    return a_i[::-1], b_i[::-1]

lcs([17, 9,99,7,4,8,3,7,5,2,4,1,2,1,2,4,5,6],
    [1, 17,11,7,4,8,3,7,0,2,4,0,2,  2,  5,6], just_length=True)

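In [ ]:
# With just_length=False, lcs() also recovers the aligned index pairs
#   (one list of matching indices per input sequence):
a_idx, b_idx = lcs([7,4,8,3,7], [7,4,0,3,7], just_length=False)
a_idx, b_idx
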
In [ ]:
# So now let's pick a word, and find its shortest range
# Then go through 1 second (=80 position) chunks within that range,
# and see how that matches each of the ~1000 spans 
# (possibly by searching over overlapping 2 second chunks) 

#word_probe = 'greatest'
#word_probe = 'getting'
word_probe = 'bachelors'

In [ ]:
# Find all the audio_spans that overlap the mask for this word
mask_probe = word_range_mask(word_probe)
spans_probe = [ s for s in audio_spans 
                if mask_probe[s['t_start']]>0 
                or mask_probe[s['t_end']-1]>0 ]
len(spans_probe), len(audio_spans), np.sum(mask_probe)/mask_probe.shape[0]

In [ ]:
# Now go through these spans in segments of 1 second, in 0.5 second increments...
probe_win = int(1./fft_step)
probe_step = probe_win//2

probe_results=None
def get_probe_results(spans_probe):
    probe_res=[]  # Will be a list of probe score arrays
    for sp in spans_probe:
        #for i in range(sp['t_start'], sp['t_end'], probe_step):
        for i in range(0, len(sp['span']), probe_step):
            segment_probe = sp['span'][i:i+probe_win]
            #print( i, sp['span'] )
            #print( segment_probe )
            #break
            # And scan the segment_probe across all the spans in the audio,
            #   With similar windowing idea

            audio_res = []
            for sa in audio_spans:
                #for j in range(sa['t_start'], sa['t_end'], probe_step):
                for j in range(0, len(sa['span']), probe_step):
                    segment_audio = sa['span'][j:j+probe_win]

                    probe_minlen = min(len(segment_probe), len(segment_audio))
                    audio_res.append( lcs(segment_audio, segment_probe)/(probe_minlen+1) )

            probe_res.append( audio_res )
        #break

    return np.array( probe_res )

# Don't do this unless you intend to...
#probe_results = get_probe_results(spans_probe)
#probe_results.shape

In [ ]:
if probe_results is not None:
    print( probe_results.sum() )
    
    plt.figure(figsize = (20,2))
    plt.imshow(probe_results, cmap='Purples')  #, interpolation='nearest'
    plt.show()

In [ ]:


In [ ]:


In [ ]:


In [ ]:


In [ ]:


In [ ]:
embedding_dim = 8

In [ ]:
# Create an embedding for all the symbols

np.random.seed(100)

sym_embed = np.random.normal( size=(len(mel_sym_chars), embedding_dim) )
sym_embed = sym_embed / np.linalg.norm(sym_embed, axis=1)[:, np.newaxis]
sym_embed[mel_sym_silence, :] = 0.

sym_embed[3,:], np.linalg.norm(sym_embed[3,:])

In [ ]:


In [ ]:
# Create combined embedding for each audio_span

#        audio_spans.append(dict(
#            t_start=span_start_index,
#            t_end  =span_start_index+len(span),
#            syms=s_i,
#            count=s_n,
#            span=span,
#        ))

def get_audio_spans_embedding(period_in_sec=None, 
                              beginning=False, ending=False):
    overall_bins= np.bincount(mel_sym)
    overall_emb = np.dot(overall_bins, sym_embed)
    overall_emb /= np.linalg.norm(overall_emb)

    emb = np.zeros( (len(audio_spans), embedding_dim))
    
    for i, s in enumerate(audio_spans):
        t_start, t_end = s['t_start'], s['t_end']
        
        if beginning:
            t_end   = t_start + period_in_sec/fft_step  # Fixed-length window
        if ending:
            t_start = t_end - period_in_sec/fft_step    # Fixed-length window
            
        inside_bins = np.bincount(mel_sym[ int(t_start):int(t_end) ], 
                                minlength=len(mel_sym_chars))

        inside_emb  = np.dot( inside_bins, sym_embed)
        inside_emb /= np.linalg.norm(inside_emb)

        #outside_emb = overall_emb - inside_emb
        span_emb = inside_emb - overall_emb

        norm = np.linalg.norm(span_emb)
        if norm>0.:
            span_emb /= norm

        #print(i, span_emb)
        emb[i,:] = span_emb
        
    return emb

#overall_emb
audio_spans_embedding = get_audio_spans_embedding()
print(audio_spans_embedding[100])

In [ ]:
plt.figure(figsize=(6,6))
plt.imshow(np.dot(audio_spans_embedding, audio_spans_embedding.T), 
           aspect='auto', origin='lower', interpolation='nearest', 
           vmin=-1., vmax=1., cmap='bwr')
plt.grid(True)
plt.show()

In [ ]:
# For every word (n_occurrences>0) create embedding via masks
#   NB: masks depend on current txt_starts, txt_err

def create_word_embeddings(ignore_rare=True, ignore_frequent=True):
    word_embed = dict()
    
    overall_bins= np.bincount(mel_sym)
    overall_emb = np.dot(overall_bins, sym_embed)
    overall_emb /= np.linalg.norm(overall_emb)
    
    #for w in words_freq_ordered:
    for w in word_to_idx.keys():
        n = len(word_to_idx[w])
        if ignore_rare and n<2: 
            # Not enough for any logic operations to make a difference...
            continue 
            
        word_mask=word_range_mask(w)
        if ignore_frequent and np.sum(word_mask)>0.80*word_mask.shape[0]:
            # Too broad to be worthwhile... (includes #EOS)
            continue 
        if w == '#EOS':
            continue
        
        inside_bins = np.bincount(mel_sym, weights=word_mask)
        inside_emb  = np.dot( inside_bins, sym_embed)
        inside_emb /= np.linalg.norm(inside_emb)

        #outside_emb = overall_emb - inside_emb
        #word_emb = inside_emb - outside_emb
        
        word_emb = inside_emb - overall_emb

        norm = np.linalg.norm(word_emb)
        if norm>0.:
            word_emb /= norm
        #else == nonsense

        word_embed[w]=word_emb
    return word_embed

word_embedding = create_word_embeddings()
#len(word_embedding)
word_embedding['love'], word_embedding['bachelors']

In [ ]:
# Create combined embedding for each sentence
#   Uses the word_embeddings above (will change if txt_starts, txt_err changes)
#txt_arr[:6], txt_arr[-6:]

#        sentence_spans.append(dict(
#            t_start= span_start_index,
#            t_end  = span_start_index+len(span),
#            span   = span,
#        ))

def create_sentence_embedding(word_embedding):
    ss_embedding = np.zeros( (len(sentence_spans), embedding_dim))
    for i, s in enumerate(sentence_spans):
        span_emb = np.zeros( (embedding_dim,) )
        for w in s['span']:
            if w in word_embedding:
                span_emb += word_embedding[w]
        norm = np.linalg.norm(span_emb)
        if norm>0.:
            span_emb /= norm
        else:
            span_emb = word_embedding['marriage']  # Arbitrary non-zero stand-in
        ss_embedding[i, :] = span_emb
    return ss_embedding

sentence_spans_embedding = create_sentence_embedding(word_embedding)        

print(sentence_spans[0]['span'])
print(sentence_spans_embedding[0])
print(sentence_spans_embedding[-1])

In [ ]:
plt.figure(figsize=(6,6))
plt.imshow(np.dot(sentence_spans_embedding, sentence_spans_embedding.T), 
           aspect='auto', origin='lower', interpolation='nearest', 
           vmin=-1., vmax=1., cmap='bwr')
plt.show()  # Sentences are much more dissimilar if rare words are rejected

In [ ]:
#sentence_spans_embedding.shape, audio_spans_embedding.shape
#((69, 8), (114, 8))

In [ ]:


In [ ]:
# Do a vector DTW across spans and sentences
 
cost_matrix, warp_path = librosa.dtw(
    audio_spans_embedding.T, sentence_spans_embedding.T, 
    metric='cosine', subseq=False,
    #step_sizes_sigma = np.array([[1, 1], [0, 1], [1, 0]]),
    #weights_add = np.array([0, 0, 0]),
    #weights_mul = np.array([1, 1, 1]),
    
    step_sizes_sigma = np.array([[1, 1], [1, 0]]),  # Disallow 2 sentences for 1 audio
    weights_add = np.array([0, 0]),
    weights_mul = np.array([1, 1]),
    
    #band_rad=0.25
)
cost_matrix.shape, warp_path.shape

In [ ]:
#   Graph out the DTW path
#     See : https://musicinformationretrieval.com/dtw_example.html
plt.imshow(cost_matrix.T, aspect='auto', origin='lower', 
           interpolation='nearest', cmap='gray')
plt.plot(warp_path[:,0], warp_path[:,1], 'r')
plt.show()

In [ ]:
i=29; sentence_spans_embedding[i,:], ' '.join(sentence_spans[i]['span'])
#warp_path  #

In [ ]:
# Function to output sentence and audio in that span
warp_path[ warp_path[:,1]==i, 0]

In [ ]:
# 48 == 'a fool and her money are soon courted'
play_audio_span(audio_spans[48])

In [ ]:
#   Output new 'known_starts' dictionary
#   Update txt_starts, txt_err
def add_warp_path_to_known_starts(warp_path, txt_err_min_new=txt_err_min*2.0):
    #known_starts = dict()
    #new_starts = {
    #    0 : (offset_start, txt_err_min_new),
    #    (txt_length_est.shape[0]-1) : (offset_end, txt_err_min_new),
    #}
    new_starts = dict( known_starts )

    #audio_spans    warp_path[:,0] t_start= span_start_index, # within mel_arr
    #sentence_spans warp_path[:,1] t_start= span_start_index, # within txt_arr
    
    # There are more audio_spans than sentence_spans
    # So, for each sentence, look for the series of audios that correspond
    # and pick the ?middle? one as the anchor
    for i, s in enumerate(sentence_spans):
        audio_span_i_arr = warp_path[ warp_path[:,1]==i, 0]
        #print(audio_span_i_arr)
        #audio_span_i = audio_span_i_arr[ audio_span_i_arr.shape[0]//2 ]
        audio_span_i = audio_span_i_arr[ -1 ]
        #print(audio_span_i)

        start_sec = audio_spans[audio_span_i]['t_start']*fft_step
        if start_sec<offset_start:
            start_sec=offset_start
        
        new_starts[ s['t_start'] ] = ( start_sec, txt_err_min_new )
        #print("%6d -> (%6.2f, %6.2f)" % (s['t_start'], start_sec, txt_err_min_new,))
    
    txt_starts, txt_err = create_starts_and_errs(new_starts)
    
    return txt_starts, txt_err
    
txt_starts, txt_err = add_warp_path_to_known_starts(warp_path, 3.0)
#known_starts

In [ ]:
word_embedding['love'], word_embedding['bachelors']

In [ ]:
word_embedding = create_word_embeddings()
sentence_spans_embedding = create_sentence_embedding(word_embedding)

In [ ]:
word_embedding['love'], word_embedding['bachelors']

In [ ]:


In [ ]:
# http://colinraffel.com/publications/thesis.pdf

# Idea : Align symbols using linear DTW (v fast)
# Assign random increments to symbols, and see whether linear DTW can match the alignments
# https://blog.acolyer.org/2016/05/11/searching-and-mining-trillions-of-time-series-subsequences-under-dynamic-time-warping/

# Idea : Align (using DTW) the mels or embeddings within the word-error-bar segments
# This would be multiple small alignments too
# Needs vector DTW (like librosa has...)

#  https://github.com/pierre-rouanet/dtw (GPL3 : Unusable)
#  https://github.com/slaypni/fastdtw/tree/master/fastdtw  (MIT)
#  https://github.com/ricardodeazambuja/DTW (cardioid example : CC0 licensed)


## Combo

# Use global word embeddings to weight samples in local ranges


## Global alignment

# Idea : Train up a word embedding based on range (within error bars)
# Find word-range average vector (vs. not-in-word-range)
# Use this to do a DTW across all words vs all timesteps

# Alternative : Use same word embedding to do some kind of annealing :
#   gradually reducing error bars (and improving embedding, etc)

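In [ ]:
# A sketch of the fastdtw idea above (MIT licensed, so usable here).
#   Assumes `pip install fastdtw` - it aligns the two embedding
#   sequences directly, as an alternative to librosa.dtw :
#from scipy.spatial.distance import cosine
#from fastdtw import fastdtw
#fast_distance, fast_path = fastdtw(audio_spans_embedding,
#                                   sentence_spans_embedding, dist=cosine)
#fast_distance, len(fast_path)
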
In [ ]:
len(audio_spans), len(sentence_spans)
#sentence_spans[32]['t_start'], txt_starts[707]

In [ ]:


In [ ]:
# Reset the sentence timings
txt_starts, txt_err = create_starts_and_errs(known_starts)

In [ ]:
# Non-DTW approach : First guess is to choose nearest-neighbours 
#  between sentence_spans position guesses and audio_spans
#  BUT : What if the nearest neighbours create dupes?  
#     Could use some kind of repulsion (like a chain of springs)...  
#     Or: First come, first served, others 'free' (or both 'free')


# Go through the sentence_spans, and walk a pointer in the audio_spans
#   beyond the estimated start.  Then choose the closest of the 
#   audio_spans starts (previous one or this one).
#   If previous one chosen, check whether it is already 'occupied' - 
#     and in that case 'free' both this and the previous sentence...
def sentence_ends_find_nearest_audio_gaps():
    s_to_a = [None]*len(sentence_spans)

    a_i = 1 # Start ahead of the beginning (eliminate one end-case)
    for i, s in enumerate(sentence_spans):
        s_t = txt_starts[ s['t_start'] ]/fft_step
        while (a_i<len(audio_spans) and audio_spans[a_i]['t_start']<s_t):
            #print(a_i, s_t)
            a_i += 1
        if a_i>=len(audio_spans):
            if s_to_a[i-1] is None:
                s_to_a[i] = a_i-1  # Assigns the last one if unclaimed
            break
        # Find which is closer
        d_this = np.abs(s_t - audio_spans[a_i  ]['t_start'])
        d_prev = np.abs(s_t - audio_spans[a_i-1]['t_start'])
        if d_this<d_prev:
            s_to_a[i] = a_i
        else:
            if i==0 or s_to_a[i-1] != a_i-1:
                s_to_a[i] = a_i-1
            else:
                s_to_a[i-1] = None # Free both
    return s_to_a

s_to_a = sentence_ends_find_nearest_audio_gaps()
for a in s_to_a: print(a,', ', end="")

In [ ]:
# Now use the values in s_to_a to create a new mapping
def s_to_a_to_starts(s_to_a, known_starts=known_starts, txt_err_min_new = 5.):
    new_starts = dict( known_starts )
    
    for s_i, a_i in enumerate(s_to_a):
        if a_i is None:
            continue
        s = sentence_spans[s_i]
        
        start_sec = audio_spans[ a_i ]['t_start']*fft_step
        if start_sec<offset_start:
            start_sec=offset_start

        new_starts[ s['t_start'] ] = ( start_sec, txt_err_min_new )
        #print("%6d -> (%6.2f, %6.2f)" % (s['t_start'], start_sec, txt_err_min_new,))
    return new_starts

# Actually, leave this to one side for a bit...
#txt_starts, txt_err = s_to_a_to_starts(s_to_a)
#s_to_a_to_starts(s_to_a)

txt_starts, txt_err_ignore = create_starts_and_errs( s_to_a_to_starts(s_to_a) )

In [ ]:
plt.plot(txt_starts, 'b')
plt.plot(txt_err*10., 'r')
plt.show()

In [ ]:
word_embedding = create_word_embeddings()
sentence_spans_embedding = create_sentence_embedding(word_embedding)    

plt.figure(figsize=(6,6))
plt.imshow(np.dot(sentence_spans_embedding, sentence_spans_embedding.T), 
           aspect='auto', origin='lower', interpolation='nearest', 
           vmin=-1., vmax=1., cmap='bwr')
plt.grid(True)
plt.show()

In [ ]:
# List neighbouring sentences that have the lowest dot products
def get_sorted_sentence_span_contrasts():
    sentence_span_contrast=[]
    for i in range(0, len(sentence_spans)-1):  # Include the last adjacent pair
        sentence_span_contrast.append( (
            np.dot(sentence_spans_embedding[i,:], sentence_spans_embedding[i+1,:]),
            i, i+1)
        )
    return sorted(sentence_span_contrast)

get_sorted_sentence_span_contrasts()[:10]

In [ ]:
i=65
print( sentence_spans_embedding[i,:])
print( sentence_spans_embedding[i+1,:])
print( np.dot(sentence_spans_embedding[i,:], sentence_spans_embedding[i+1,:]) )
print(str(i)  +') '+' '.join(sentence_spans[i]['span']) )
print(str(i+1)+') '+' '.join(sentence_spans[i+1]['span']) )

In [ ]:
# Loop through audio_spans within 'striking range' of a given sentence start
#  And find the dot product with that sentence

def return_sentence_vs_audio_dots(i):
    #print(sentence_spans[i])
    t_start = sentence_spans[i]['t_start']
    t_end   = sentence_spans[i]['t_end']
    if t_end+1 < len(txt_starts):
        t_end += 1
    
    #print( txt_starts[ t_start ], txt_err[ t_start ] )
    #print( txt_starts[ t_end ], txt_err[ t_end ] )
    
    t_min = (txt_starts[ t_start ] - txt_err[ t_start ])/fft_step
    t_max = (txt_starts[ t_end ]   + txt_err[ t_end ]  )/fft_step
    #print(t_min, t_max)
    
    a_arr, dots = [],[]
    for a_i, a in enumerate(audio_spans):
        if a['t_start']>t_max or a['t_end']<t_min:
            continue
        a_arr.append(a_i)
        dots.append(np.dot(sentence_spans_embedding[i,:], 
                           audio_spans_embedding[a_i, :] ))
    #print(dots)
    return a_arr, dots

def show_audio_span_matches(i):
    # i set in previous cell : 'contrasting adjacent sentences'
    x,y = return_sentence_vs_audio_dots(i)
    plt.plot(x,y, 'b-*')

    x,y = return_sentence_vs_audio_dots(i+1)
    plt.plot(x,y, 'r-*')

    plt.grid(True)
    plt.title("Looking for blue decline with simultaneous red increase")
    plt.show()
    
show_audio_span_matches(i)

In [ ]:
j=100
#print( audio_spans[j]['t_start'], audio_spans[j]['t_end'], )
play_audio_span(audio_spans[j], autoplay=True)  
#play_audio_span(audio_spans[j+1])

In [ ]:


In [ ]:


In [ ]:


In [ ]:
# Reset the sentence timings roughly
current_starts = dict( known_starts )
#txt_starts, txt_err = create_starts_and_errs(known_starts)

In [ ]:
#current_starts = dict( known_starts )

In [ ]:
# Loop below here

In [ ]:
txt_starts, txt_err = create_starts_and_errs(current_starts)

# Now shift the timings, so that sentence starts hit audio starts exactly
s_to_a = sentence_ends_find_nearest_audio_gaps()

# This starts-dict aligns the sentences to the nearest audio
tmp_starts=s_to_a_to_starts(s_to_a, known_starts=current_starts)

# This doesn't update the txt_errs, which are probably sort-of-right
txt_starts, txt_err_ignore = create_starts_and_errs( tmp_starts )

plt.plot(txt_starts, 'b')
plt.plot(txt_err*10., 'r')
plt.show()

In [ ]:
word_embedding = create_word_embeddings()
sentence_spans_embedding = create_sentence_embedding(word_embedding)    

get_sorted_sentence_span_contrasts()[:10]

In [ ]:
i=64  # This is the last number in the tuple (higher one)
print(str(i-1)+') '+' '.join(sentence_spans[i-1]['span']) )
show_audio_span_matches(i-1)
print(str(i  )+') '+' '.join(sentence_spans[i  ]['span']) )

In [ ]:
j=109  # Look for start of second, red, audio_span (easier to spot)
play_audio_span(audio_spans[j], autoplay=True)

In [ ]:
# Add the j found to the current_starts
current_starts[sentence_spans[i]['t_start']] = (
    audio_spans[ j ]['t_start']*fft_step, txt_err_min 
)

In [ ]:
# Now loop around...

In [ ]:
current_starts

In [ ]:
# Possible criteria : 
#   Must score > 0.4 on blue and red either side
#   Must score less on red and blue either side (i.e. there's a crossing-point)

# Also filter out :
#   One side is much shorter than the other

# Perhaps do some more rounds to work out actual mapping 
#   (i.e. allow cheating/peeking this time)

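In [ ]:
# A sketch of the criteria above - the threshold and logic are 
#   illustrative, not tuned.  Accept the first adjacent audio_span pair
#   where blue (sentence i) hands over to red (sentence i+1) with both
#   scoring above threshold:
def find_crossing_audio_span(i, threshold=0.4):
    a_blue, d_blue = return_sentence_vs_audio_dots(i)
    a_red,  d_red  = return_sentence_vs_audio_dots(i+1)
    blue = dict(zip(a_blue, d_blue))
    red  = dict(zip(a_red,  d_red))
    common = sorted( set(a_blue) & set(a_red) )
    for a_prev, a_next in zip(common[:-1], common[1:]):
        if (    blue[a_prev]>threshold and red[a_next]>threshold
            and blue[a_prev]>red[a_prev] and red[a_next]>blue[a_next]):
            return a_next  # First audio_span on the 'red' side
    return None

#find_crossing_audio_span(i)
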
In [ ]:


In [ ]:
# Now what?  We have an alignment that could let us train the words 
#   (either better, or for the first time)
# Assume we train the words now (and get sentence embeddings)
#   Could we do some kind of nudging into the right alignment?

#   Looking at behaviour of sentence-sentence embedding cosines above
#   it seems like things are initially very vague, but later form 
#   stronger (mainly negative) opinions where the alignment is wrong

In [ ]:


In [ ]:
# All New Approach : audio_spans and sentence_spans : starts and ends

# Find the embedding for the first and last 2 seconds of each audio span
# Find the embedding for the first and last ~2 seconds of each sentence span

# Do the matching as above (but be aware of direction of comparison)

In [ ]:


In [ ]:
matching_period = 5.0 # in seconds

# Audio span embeddings
audio_spans_embedding_starts = get_audio_spans_embedding(
    period_in_sec=matching_period, beginning=True)
audio_spans_embedding_ends   = get_audio_spans_embedding(
    period_in_sec=matching_period, ending=True)

#overall_emb
j=103
print(audio_spans_embedding_starts[j])
print(audio_spans_embedding_ends[j])

In [ ]:
word_embedding = create_word_embeddings(ignore_rare=False, ignore_frequent=False)

In [ ]:
sentence_spans[0]

In [ ]:
def create_sentence_embedding_matching(word_embedding, period_in_sec=None, 
                                          beginning=False, ending=False, debug_i=-1):
    ss_embedding = np.zeros( (len(sentence_spans), embedding_dim))
    for i, s in enumerate(sentence_spans):
        span_emb = np.zeros( (embedding_dim,) )

        if beginning:
            txt_starts_i = s['t_start']  # Use these to get the lengths in seconds
            t_max = txt_starts[ txt_starts_i ] + period_in_sec
            #for j, w in enumerate(s['span']):
            for j, w in enumerate( txt_arr[txt_starts_i:] ):
                if w in word_embedding:
                    span_emb += word_embedding[w]
                    if i==debug_i:
                        print('start : '+w)
                if txt_starts[ txt_starts_i+j ]>t_max:
                    break

        if ending:
            txt_starts_i = s['t_end']  # Use these to get the lengths in seconds
            #print(i, txt_starts_i)
            if txt_starts_i>len(txt_starts)-1: 
                txt_starts_i=len(txt_starts)-1
            t_min = txt_starts[ txt_starts_i ] - period_in_sec
            #for j, w in enumerate(s['span'][::-1]): # Go backwards
            for j, w in enumerate(txt_arr[txt_starts_i::-1]): # Go backwards
                if w in word_embedding:
                    span_emb += word_embedding[w]
                    if i==debug_i:
                        print('end : '+w)
                if txt_starts[ txt_starts_i-j ]<t_min:
                    break

        norm = np.linalg.norm(span_emb)
        if norm>0.:
            span_emb /= norm
        else:
            print("No embeddings found for sentence "+str(i))
            span_emb = word_embedding['marriage']  # Arbitrary non-zero stand-in
        ss_embedding[i, :] = span_emb
    return ss_embedding

# Test this once
i=3

#matching_period=4.0
sentence_spans_embedding_starts = create_sentence_embedding_matching(word_embedding, 
    period_in_sec=matching_period, beginning=True, debug_i=i)
sentence_spans_embedding_ends   = create_sentence_embedding_matching(word_embedding, 
    period_in_sec=matching_period, ending=True, debug_i=i)

print(  txt_starts[ sentence_spans[i]['t_end']+1] 
      - txt_starts[ sentence_spans[i]['t_start'] ])
print(sentence_spans_embedding_starts[i])
print(sentence_spans_embedding_ends[i])

In [ ]:


In [ ]:
# List neighbouring sentences that have the lowest dot products
def get_sorted_sentence_span_match_contrasts():
    match_contrast=[]
    for i in range(0, len(sentence_spans)-1):  # Include the last adjacent pair
        match_contrast.append( (
            np.dot(sentence_spans_embedding_ends[i,:], 
                   sentence_spans_embedding_starts[i+1,:]),
            i, i+1)
        )
    return sorted(match_contrast)

get_sorted_sentence_span_match_contrasts()[:10]

In [ ]:
# Loop through audio_spans within 'striking range' of a given sentence start
#  And find the dot product with that sentence

def return_sentence_vs_audio_dots_match(s_i, beginning=False, ending=False):
    #print(sentence_spans[i])
    t_start = sentence_spans[s_i]['t_start']
    t_end   = sentence_spans[s_i]['t_end']
    if t_end+1 < len(txt_starts):
        t_end += 1
    
    #print( txt_starts[ t_start ], txt_err[ t_start ] )
    #print( txt_starts[ t_end ], txt_err[ t_end ] )
    
    t_min = (txt_starts[ t_start ] - txt_err[ t_start ])/fft_step
    t_max = (txt_starts[ t_end ]   + txt_err[ t_end ]  )/fft_step
    #print(t_min, t_max)
    
    a_arr, dots = [],[]
    for a_i, a in enumerate(audio_spans):
        if a['t_start']>t_max or a['t_end']<t_min:
            continue
        a_arr.append(a_i)
        if ending:    # Match the audio_span ending with sentence ending
            dots.append(np.dot(sentence_spans_embedding_ends[s_i, :], 
                               audio_spans_embedding_ends[a_i, :] ))
        if beginning: # Match the audio_span beginning with sentence beginning
            dots.append(np.dot(sentence_spans_embedding_starts[s_i, :], 
                               audio_spans_embedding_starts[a_i, :] ))
    #print(dots)
    return a_arr, dots

def show_audio_span_period_matches(s_i):
    # i set in previous cell : 'contrasting adjacent sentences'
    x,y = return_sentence_vs_audio_dots_match(s_i, ending=True)
    plt.plot(x,y, 'b-*')

    x,y = return_sentence_vs_audio_dots_match(s_i+1, beginning=True)
    plt.plot(x,y, 'r-*')

    plt.grid(True)
    plt.title("Looking for blue peak followed by a red peak")
    plt.show()
    
show_audio_span_period_matches(i)

In [ ]:


In [ ]:


In [ ]:
# Reset the sentence timings roughly
current_starts = dict( known_starts )

In [ ]:
# Loop starts here

In [ ]:
txt_starts, txt_err = create_starts_and_errs(current_starts)

# Now shift the timings, so that sentence starts hit audio starts exactly
s_to_a = sentence_ends_find_nearest_audio_gaps()

# This starts-dict aligns the sentences to the nearest audio
tmp_starts=s_to_a_to_starts(s_to_a, known_starts=current_starts)

# This doesn't update the txt_errs, which are probably sort-of-right
txt_starts, txt_err_ignore = create_starts_and_errs( tmp_starts )

plt.plot(txt_starts, 'b')
plt.plot(txt_err*10., 'r')
plt.show()

In [ ]:
word_embedding = create_word_embeddings(ignore_rare=False, ignore_frequent=False)

sentence_spans_embedding_starts = create_sentence_embedding_matching(word_embedding, 
    period_in_sec=matching_period, beginning=True)
sentence_spans_embedding_ends   = create_sentence_embedding_matching(word_embedding, 
    period_in_sec=matching_period, ending=True)

get_sorted_sentence_span_match_contrasts()[:15]

In [ ]:
i=40  # This is the last number in the tuple (higher one)
print(str(i-1)+') '+' '.join(sentence_spans[i-1]['span']) )
show_audio_span_period_matches(i-1)
print(str(i  )+') '+' '.join(sentence_spans[i  ]['span']) )

In [ ]:
j=64  # Look for start of second, red, audio_span (easier to spot)
play_audio_span(audio_spans[j], autoplay=True)

In [ ]:
# Add the j found to the current_starts
current_starts[sentence_spans[i]['t_start']] = (
    audio_spans[ j ]['t_start']*fft_step, txt_err_min 
)

In [ ]:
# Now loop around

In [ ]: