In [1]:
import sys
import numpy as np
import random
import itertools
import heapq as hq

In [2]:
sys.path.append('src/')

In [3]:
#import pyximport; pyximport.install()
from inference_lv import do_inference_list_viterbi

In [4]:
#from inference import do_inference_brute_force

In [5]:
random.seed(0)
np.random.seed(0)

In [6]:
class HeapItem:  # an item in heapq (min-heap)
    def __init__(self, priority, task):
        self.priority = priority
        self.task = task
        self.string = str(priority) + ': ' + str(task)

    def __lt__(self, other):
        return self.priority < other.priority

    def __repr__(self):
        return self.string

    def __str__(self):
        return self.string

Brute force search.


In [7]:
def brute_force(ps, L, M, unary_params, pw_params, unary_features, pw_features, debug=True, top=5):
    Cu = np.zeros(M, dtype=np.float)       # unary_param[p] x unary_features[p]
    Cp = np.zeros((M, M), dtype=np.float)  # pw_param[pi, pj] x pw_features[pi, pj]
    # a intermediate POI should NOT be the start POI, NO self-loops
    for pi in range(M):
        Cu[pi] = np.dot(unary_params[pi, :], unary_features[pi, :])
        for pj in range(M):
            Cp[pi, pj] = -np.inf if (pj == ps or pi == pj) else np.dot(pw_params[pi, pj, :], pw_features[pi, pj, :])

    Q = []
    poi_set = [p for p in range(M) if p != ps]
    for x in itertools.product(poi_set, repeat=L-1):
        y = [ps] + list(x)
        score = 0

        for j in range(1, L):
            score += Cp[y[j - 1], y[j]] + Cu[y[j]]

        if len(Q) < top:
            hq.heappush(Q, HeapItem(score, np.array(y)))
        else:
            hq.heappushpop(Q, HeapItem(score, np.array(y)))  # pop the smallest, then push

    results = []
    scores = []
    while len(Q) > 0:
        hterm = hq.heappop(Q)
        results.append(hterm.task)
        scores.append(hterm.priority)

    # reverse the order: smallest -> largest => largest -> smallest
    results.reverse()
    scores.reverse()

    if debug is True:
        for score, y in zip(scores, results):
            print(y, score)

    return results

The list Viterbi algorithm described in paper Sequentially finding the N-best list in hidden Markov models (2001).


In [36]:
def list_viterbi_2001(ps, L, M, unary_params, pw_params, unary_features, pw_features, top=10):
    assert(L > 1)
    assert(M >= L)
    assert(ps >= 0)
    assert(ps < M)
    assert(top > 0)
    
    maxIter = 1e6
    
    Cu = np.zeros(M, dtype=np.float)
    Cp = np.zeros((M, M), dtype=np.float)
    
    for pi in range(M):
        Cu[pi] = np.dot(unary_params[pi, :], unary_features[pi, :])
        for pj in range(M):
            Cp[pi, pj] = -np.inf if (pj == ps or pi == pj) else np.dot(pw_params[pi, pj, :], pw_features[pi, pj, :])
            
    Alpha = np.zeros((L, M), dtype=np.float)
    Beta  = np.zeros((L, M), dtype=np.float)
    
    for pj in range(M): Alpha[1, pj] = Cp[ps, pj] + Cu[pj]
    for t in range(2, L):
        for pj in range(M):
            Alpha[t, pj] = np.max([Alpha[t-1, pi] + Cp[pi, pj] + Cu[pj] for pi in range(M)])
    
    for pi in range(M): Beta[L-1, pi] = 0 
    for t in range(L-1, 1, -1):
        for pi in range(M):
            Beta[t-1, pi] = np.max([Cp[pi, pj] + Cu[pj] + Beta[t, pj] for pj in range(M)])
    Beta[0, ps] = np.max([Cp[ps, pj] + Cu[pj] + Beta[1, pj] for pj in range(M)])
    
    Fp = np.zeros((L-1, M, M), dtype=np.float)
    for t in range(L-1):
        for pi in range(M):
            for pj in range(M):
                Fp[t, pi, pj] = Alpha[t, pi] + Cp[pi, pj] + Cu[pj] + Beta[t+1, pj]
                
    y_best = np.ones(L, dtype=np.int) * (-1)
    y_best[0] = ps
    for t in range(1, L): y_best[t] = np.argmax(Fp[t-1, y_best[t-1], :])
    
    Q = []
    priority = -np.max(Alpha[L-1, :])
    print(-priority)
    print(Alpha[L-1, y_best[L-1]])
    print(Fp[L-2, y_best[L-2], y_best[L-1]])
    partition_index = -1
    exclude_set = set()  
    hq.heappush(Q, HeapItem(priority, (y_best, partition_index, exclude_set)))
    
    results = []
    k = 0; y_last = None
    while len(Q) > 0 and k < maxIter:
        hitem = hq.heappop(Q)
        k_priority = hitem.priority
        (k_best, k_partition_index, k_exclude_set) = hitem.task
        k += 1; y_last = k_best
        print('OUT: %s, %.6f' % (k_best, -k_priority))
        
        if len(set(k_best)) == L:
            results.append(k_best); top -= 1
            if top == 0: return results

        partition_index_start = 1
        if k_partition_index > 0:
            assert(k_partition_index < L)
            partition_index_start = k_partition_index

        for parix in range(partition_index_start, L):
            new_best = np.zeros(L, dtype=np.int) * (-1)
            new_best[:parix] = k_best[:parix]
            
            new_exclude_set = set({k_best[parix]})
            if parix == partition_index_start: new_exclude_set = new_exclude_set | k_exclude_set
            candidate_points = [p for p in range(M) if p not in new_exclude_set]
            if len(candidate_points) == 0: continue
            candidate_maxix = np.argmax([Fp[parix-1, k_best[parix-1], p] for p in candidate_points])
            new_best[parix] = candidate_points[candidate_maxix]
            
            for pk in range(parix+1, L): 
                new_best[pk] = np.argmax([Fp[pk-1, new_best[pk-1], p] for p in range(M)])
            
            new_priority = (-k_priority)
            new_priority += Fp[parix-1, k_best[parix-1], new_best[parix]] - Fp[parix-1, k_best[parix-1], k_best[parix]]
            new_priority *= -1.0

            hq.heappush(Q, HeapItem(new_priority, (new_best, parix, new_exclude_set)))
            #print('IN : %s, %.6f' % (new_best, -new_priority))
            
    if len(Q) == 0:
        sys.stderr.write('WARN: empty queue, return the last one\n')
    results.append(y_last); top -= 1
    while len(Q) > 0 and top > 0:
        hitem = hq.heappop(Q)
        results.append(hitem.task[0]); top -= 1
    return results

The list Viterbi algorithm described in paper List Viterbi decoding algorithms with applications (1994).


In [38]:
def list_viterbi_1994(ps, L, M, unary_params, pw_params, unary_features, pw_features, top=10):
    assert(L > 1)
    assert(M >= L)
    assert(ps >= 0)
    assert(ps < M)
    assert(top > 0)
    
    maxIter = 1e6
    
    Cu = np.zeros(M, dtype=np.float)
    Cp = np.zeros((M, M), dtype=np.float)
    
    for pi in range(M):
        Cu[pi] = np.dot(unary_params[pi, :], unary_features[pi, :])
        for pj in range(M):
            Cp[pi, pj] = -np.inf if (pj == ps or pi == pj) else np.dot(pw_params[pi, pj, :], pw_features[pi, pj, :])
            
    Alpha = np.zeros((L, M), dtype=np.float)
    Beta  = np.zeros((L, M), dtype=np.float)
    
    for pj in range(M): Alpha[1, pj] = Cp[ps, pj] + Cu[pj]
    for t in range(2, L):
        for pj in range(M):
            Alpha[t, pj] = np.max([Alpha[t-1, pi] + Cp[pi, pj] + Cu[pj] for pi in range(M)])
    
    for pi in range(M): Beta[L-1, pi] = 0 
    for t in range(L-1, 1, -1):
        for pi in range(M):
            Beta[t-1, pi] = np.max([Cp[pi, pj] + Cu[pj] + Beta[t, pj] for pj in range(M)])
    Beta[0, ps] = np.max([Cp[ps, pj] + Cu[pj] + Beta[1, pj] for pj in range(M)])
    
    Fp = np.zeros((L-1, M, M), dtype=np.float)
    for t in range(L-1):
        for pi in range(M):
            for pj in range(M):
                Fp[t, pi, pj] = Alpha[t, pi] + Cp[pi, pj] + Cu[pj] + Beta[t+1, pj]
                
    y_best = np.ones(L, dtype=np.int) * (-1)
    y_best[0] = ps
    for t in range(1, L): y_best[t] = np.argmax(Fp[t-1, y_best[t-1], :])
    
    Q = []
    priority = -np.max(Alpha[L-1, :])
    print(-priority)
    print(Alpha[L-1, y_best[L-1]])
    print(Fp[L-2, y_best[L-2], y_best[L-1]])
    partition_index = -1
    exclude_set = set()  
    hq.heappush(Q, HeapItem(priority, (y_best, partition_index, exclude_set)))
    
    print(y_best)
    print('-----------------------')
    y1 = np.ones(L, dtype=np.int) * -1
    y1[0] = y_best[0]
    y1[2:] = y_best[2:]
    cpoints = [p for p in range(M) if p != y_best[1]]
    y1scores = [Cp[y1[0], p] + Cu[p] + Cp[p, y1[2]] for p in cpoints]
    y1ix = np.argmax(y1scores)
    y1[1] = cpoints[y1ix]
    print(y1)
    print('-----------------------')
    score1 = -priority + Fp[1, y1[1], y_best[2]] - Fp[1, y_best[1], y_best[2]]
    score2 = 0
    for j in range(L-1):
        ss = y1[j]
        tt = y1[j+1]
        score2 += Cp[ss, tt] + Cu[tt]
    #print(-priority)
    print(score1)
    print(score2)
    print('-----------------------')
    
    results = []
    k = 0; y_last = None
    while len(Q) > 0 and k < maxIter:
        hitem = hq.heappop(Q)
        k_priority = hitem.priority
        (k_best, k_partition_index, k_exclude_set) = hitem.task
        k += 1; y_last = k_best
        print('OUT: %s, %.6f' % (k_best, -k_priority))
        
        if len(set(k_best)) == L:
            results.append(k_best); top -= 1
            if top == 0: return results

        partition_index_start = L-1
        if k_partition_index > 0:
            assert(k_partition_index < L)
            partition_index_start = k_partition_index

        for parix in range(partition_index_start, 0, -1):
            new_best = np.zeros(L, dtype=np.int) * (-1)
            
            # new_best[parix+1:]
            new_best[parix+1:] = k_best[parix+1:]
            new_best[0] = ps
            
            # new_best[parix]
            new_exclude_set = set({k_best[parix]})
            if parix == partition_index_start: new_exclude_set = new_exclude_set | k_exclude_set
            candidate_points = [p for p in range(M) if p not in new_exclude_set]
            if len(candidate_points) == 0: continue
            if parix == 1:
                candidate_maxix = np.argmax([Cp[ps, p] + Cu[p] + Cp[p, k_best[2]] for p in candidate_points])
                new_best[parix] = candidate_points[candidate_maxix]
            elif parix == L-1:
                candidate_maxix = np.argmax([Alpha[L-1, p] for p in candidate_points])
                new_best[parix] = candidate_points[candidate_maxix]
            else:
                candidate_maxix = np.argmax([Fp[parix, p, k_best[parix+1]] for p in candidate_points])
                new_best[parix] = candidate_points[candidate_maxix]
                
            # new_best[:parix]
            if parix > 1:
                for pk in range(parix-1, 0, -1): 
                    if pk == 1:
                        new_best[pk] = np.argmax([Cp[ps, p] + Cu[p] + Cp[p, new_best[2]] for p in range(M)])
                    else:
                        new_best[pk] = np.argmax([Fp[pk, p, new_best[pk+1]] for p in range(M)])
            
            # sequence score
            new_priority = (-k_priority)
            if parix == L-1:
                new_priority += Alpha[L-1, new_best[L-1]] - Alpha[L-1, k_best[L-1]]
            else:
                new_priority += Fp[parix, new_best[parix], k_best[parix+1]] - Fp[parix, k_best[parix], k_best[parix+1]]
            new_priority *= -1.0
            new_score = 0
            for j in range(L-1):
                ss = new_best[j]
                tt = new_best[j+1]
                new_score += Cp[ss, tt] + Cu[tt]
            #print(-new_priority); print(new_score) 
            assert(np.isclose(-new_priority, new_score))
            new_priority = -new_score

            hq.heappush(Q, HeapItem(new_priority, (new_best, parix, new_exclude_set)))
            #print('IN : %s, %.6f, %d' % (new_best, -new_priority, parix))
            
    if len(Q) == 0:
        sys.stderr.write('WARN: empty queue, return the last one\n')
    results.append(y_last); top -= 1
    while len(Q) > 0 and top > 0:
        hitem = hq.heappop(Q)
        results.append(hitem.task[0]); top -= 1
    return results

In [26]:
#M0 = 90
M0 = 10
n_u = 10
n_p = 5
w_u = np.random.rand(M0*n_u).reshape(M0, n_u)
f_u = np.random.rand(M0*n_u).reshape(M0, n_u)
w_p = np.random.rand(M0*M0*n_p).reshape(M0, M0, n_p)
f_p = np.random.rand(M0*M0*n_p).reshape(M0, M0, n_p)
ps0 = np.random.choice(np.arange(M0))
L0 = np.random.choice(np.arange(2, 8))
#L0 = 10
indices0 = [x for x in range(M0) if x != ps0]; np.random.shuffle(indices0)
y_true0 = [ps0] + indices0[:L0-1]
y_true_list0 = [y_true0]
for j in range(8):
    np.random.shuffle(indices0); y_true_list0.append([ps0] + indices0[:L0-1])
print(ps0, L0)


8 6

In [27]:
brute_force(ps0, L0, M0, w_u, w_p, f_u, f_p, debug=True)


[8 3 6 5 3 6] 24.8025703379
[8 3 5 3 2 0] 24.7780823383
[8 3 2 3 2 0] 24.7382093348
[8 3 5 3 5 3] 24.7239254745
[8 5 3 5 3 6] 24.705130814
Out[27]:
[array([8, 3, 6, 5, 3, 6]),
 array([8, 3, 5, 3, 2, 0]),
 array([8, 3, 2, 3, 2, 0]),
 array([8, 3, 5, 3, 5, 3]),
 array([8, 5, 3, 5, 3, 6])]

In [28]:
do_inference_list_viterbi(ps0, L0, M0, w_u, w_p, f_u, f_p)


best seq: [8 3 6 5 3 6]
score([8 3 6 5 3 6]): 24.802570
f_{8,3)}: 24.802570
f_{3,6)}: 24.802570
f_{6,5)}: 24.802570
f_{5,3)}: 24.802570
f_{3,6)}: 24.802570
OUT: [8 3 6 5 3 6], 24.80257
IN : [8 5 3 5 3 6], 24.70513
IN : [8 3 5 3 2 0], 24.77808
IN : [8 3 6 3 2 0], 24.44477
IN : [8 3 6 5 9 4], 23.74422
IN : [8 3 6 5 3 2], 24.69850
OUT: [8 3 5 3 2 0], 24.77808
IN : [8 3 2 3 2 0], 24.73821
IN : [8 3 5 4 3 6], 23.60944
IN : [8 3 5 3 5 3], 24.72392
IN : [8 3 5 3 2 3], 24.68405
OUT: [8 3 2 3 2 0], 24.73821
IN : [8 3 9 4 3 6], 24.26520
IN : [8 3 2 5 3 6], 24.59653
IN : [8 3 2 3 5 3], 24.68405
IN : [8 3 2 3 2 3], 24.64418
OUT: [8 3 5 3 5 3], 24.72392
IN : [8 3 5 3 9 4], 24.42065
IN : [8 3 5 3 5 4], 23.76490
OUT: [8 5 3 5 3 6], 24.70513
IN : [8 9 4 3 2 0], 24.47678
IN : [8 5 4 3 2 0], 23.52841
IN : [8 5 3 6 5 3], 24.66739
IN : [8 5 3 5 9 4], 23.64678
IN : [8 5 3 5 3 2], 24.60106
OUT: [8 3 6 5 3 2], 24.69850
IN : [8 3 6 5 3 5], 24.34892
OUT: [8 3 5 3 2 3], 24.68405
IN : [8 3 5 3 2 5], 24.14288
OUT: [8 3 2 3 5 3], 24.68405
IN : [8 3 2 3 9 4], 24.38078
IN : [8 3 2 3 5 4], 23.72502
OUT: [8 5 3 6 5 3], 24.66739
IN : [8 5 3 2 3 6], 24.66526
IN : [8 5 3 6 3 6], 24.37181
IN : [8 5 3 6 5 4], 23.70836
OUT: [8 5 3 2 3 6], 24.66526
IN : [8 5 3 9 4 3], 24.13001
IN : [8 5 3 2 0 4], 24.58507
IN : [8 5 3 2 3 2], 24.56119
OUT: [8 3 2 3 2 3], 24.64418
IN : [8 3 2 3 2 5], 24.10301
OUT: [8 5 3 5 3 2], 24.60106
IN : [8 5 3 5 3 5], 24.25148
OUT: [8 3 2 5 3 6], 24.59653
IN : [8 3 2 0 3 6], 24.49060
IN : [8 3 2 5 9 4], 23.53818
IN : [8 3 2 5 3 2], 24.49246
OUT: [8 5 3 2 0 4], 24.58507
24.585073471069336
IN : [8 5 3 2 5 3], 24.46135
IN : [8 5 3 2 0 3], 24.35541
OUT: [8 5 3 2 3 2], 24.56119
IN : [8 5 3 2 3 5], 24.21161
OUT: [8 3 2 5 3 2], 24.49246
IN : [8 3 2 5 3 5], 24.14288
OUT: [8 3 2 0 3 6], 24.49060
IN : [8 3 2 6 5 3], 24.07071
IN : [8 3 2 0 4 3], 24.44841
IN : [8 3 2 0 3 2], 24.38653
OUT: [8 9 4 3 2 0], 24.47678
24.476781845092773
IN : [8 6 5 3 2 0], 24.33257
IN : [8 9 5 3 2 0], 24.26275
IN : [8 9 4 5 3 6], 24.22352
IN : [8 9 4 3 5 3], 24.42262
IN : [8 9 4 3 2 3], 24.38275
OUT: [8 5 3 2 5 3], 24.46135
IN : [8 5 3 2 9 4], 23.71866
IN : [8 5 3 2 5 4], 23.50232
OUT: [8 3 2 0 4 3], 24.44841
IN : [8 3 2 0 9 4], 24.09902
IN : [8 3 2 0 4 6], 24.13947
OUT: [8 3 6 3 2 0], 24.44477
IN : [8 3 6 9 4 3], 24.07659
IN : [8 3 6 3 5 3], 24.39061
IN : [8 3 6 3 2 3], 24.35074
OUT: [8 9 4 3 5 3], 24.42262
IN : [8 9 4 3 9 4], 24.11935
IN : [8 9 4 3 5 4], 23.46360
OUT: [8 3 5 3 9 4], 24.42065
IN : [8 3 5 3 6 3], 24.39061
IN : [8 3 5 3 9 5], 23.59752
OUT: [8 3 5 3 6 3], 24.39061
IN : [8 3 5 3 7 5], 23.37896
IN : [8 3 5 3 6 5], 24.34892
OUT: [8 3 6 3 5 3], 24.39061
IN : [8 3 6 3 9 4], 24.08734
IN : [8 3 6 3 5 4], 23.43158
OUT: [8 3 2 0 3 2], 24.38653
IN : [8 3 2 0 3 5], 24.03695
OUT: [8 9 4 3 2 3], 24.38275
IN : [8 9 4 3 2 5], 23.84158
OUT: [8 3 2 3 9 4], 24.38078
IN : [8 3 2 3 6 3], 24.35074
IN : [8 3 2 3 9 5], 23.55765
OUT: [8 5 3 6 3 6], 24.37181
IN : [8 5 3 6 9 4], 24.21326
IN : [8 5 3 6 3 2], 24.26775
OUT: [8 5 3 2 0 3], 24.35541
IN : [8 5 3 2 0 9], 23.77742
OUT: [8 3 6 3 2 3], 24.35074
IN : [8 3 6 3 2 5], 23.80957
OUT: [8 3 2 3 6 3], 24.35074
IN : [8 3 2 3 7 5], 23.33908
IN : [8 3 2 3 6 5], 24.30905
OUT: [8 3 6 5 3 5], 24.34892
IN : [8 3 6 5 3 9], 24.19649
OUT: [8 3 5 3 6 5], 24.34892
IN : [8 3 5 3 6 9], 24.04563
OUT: [8 6 5 3 2 0], 24.33257
24.33257293701172
IN : [8 2 3 5 3 6], 24.25605
IN : [8 6 3 5 3 6], 23.98285
IN : [8 6 5 4 3 6], 23.16393
IN : [8 6 5 3 5 3], 24.27842
IN : [8 6 5 3 2 3], 24.23854
OUT: [8 3 2 3 6 5], 24.30905
IN : [8 3 2 3 6 9], 24.00576
OUT: [8 6 5 3 5 3], 24.27842
IN : [8 6 5 3 9 4], 23.97514
IN : [8 6 5 3 5 4], 23.31939
OUT: [8 5 3 6 3 2], 24.26775
IN : [8 5 3 6 3 5], 23.91817
OUT: [8 3 9 4 3 6], 24.26520
IN : [8 3 7 5 3 6], 23.83260
IN : [8 3 9 5 3 6], 24.05116
IN : [8 3 9 4 5 3], 23.94970
IN : [8 3 9 4 3 2], 24.16113
OUT: [8 9 5 3 2 0], 24.26275
24.262752532958984
IN : [8 9 3 5 3 6], 23.81970
IN : [8 9 5 4 3 6], 23.09411
IN : [8 9 5 3 5 3], 24.20860
IN : [8 9 5 3 2 3], 24.16872
OUT: [8 2 3 5 3 6], 24.25605
IN : [8 4 3 5 3 6], 24.16379
IN : [8 2 5 3 2 0], 24.10629
IN : [8 2 3 6 5 3], 24.21830
IN : [8 2 3 5 9 4], 23.19769
IN : [8 2 3 5 3 2], 24.15198
OUT: [8 5 3 5 3 5], 24.25148
IN : [8 5 3 5 3 9], 24.09905
OUT: [8 6 5 3 2 3], 24.23854
IN : [8 6 5 3 2 5], 23.69737
OUT: [8 9 4 5 3 6], 24.22352
24.223520278930664
IN : [8 9 4 6 5 3], 24.07594
IN : [8 9 4 5 9 4], 23.16517
IN : [8 9 4 5 3 2], 24.11945
OUT: [8 2 3 6 5 3], 24.21830
IN : [8 2 3 2 3 6], 24.21617
IN : [8 2 3 6 3 6], 23.92273
IN : [8 2 3 6 5 4], 23.25927
OUT: [8 2 3 2 3 6], 24.21617
IN : [8 2 3 9 4 3], 23.68093
IN : [8 2 3 2 0 4], 24.13599
IN : [8 2 3 2 3 2], 24.11210
OUT: [8 5 3 6 9 4], 24.21326
24.2132568359375
IN : [8 5 3 6 2 0], 23.91760
IN : [8 5 3 6 9 5], 23.39012
OUT: [8 5 3 2 3 5], 24.21161
IN : [8 5 3 2 3 9], 24.05918
OUT: [8 9 5 3 5 3], 24.20860
IN : [8 9 5 3 9 4], 23.90532
IN : [8 9 5 3 5 4], 23.24957
OUT: [8 3 6 5 3 9], 24.19649
IN : [8 3 6 5 3 4], 23.43136
OUT: [8 9 5 3 2 3], 24.16872
IN : [8 9 5 3 2 5], 23.62755
OUT: [8 4 3 5 3 6], 24.16379
IN : [8 7 5 3 2 0], 23.27900
IN : [8 4 5 3 2 0], 23.90246
IN : [8 4 3 6 5 3], 24.12605
IN : [8 4 3 5 9 4], 23.10544
IN : [8 4 3 5 3 2], 24.05972
OUT: [8 3 9 4 3 2], 24.16113
IN : [8 3 9 4 3 5], 23.81155
OUT: [8 2 3 5 3 2], 24.15198
IN : [8 2 3 5 3 5], 23.80240
OUT: [8 3 5 3 2 5], 24.14288
IN : [8 3 5 3 2 6], 24.10846
OUT: [8 3 2 5 3 5], 24.14288
IN : [8 3 2 5 3 9], 23.99045
OUT: [8 3 2 0 4 6], 24.13947
24.139469146728516
IN : [8 3 2 0 4 9], 23.91531
OUT: [8 2 3 2 0 4], 24.13599
IN : [8 2 3 2 5 3], 24.01226
IN : [8 2 3 2 0 3], 23.90633
OUT: [8 5 3 9 4 3], 24.13001
IN : [8 5 3 7 5 3], 23.69742
IN : [8 5 3 9 5 3], 23.91598
IN : [8 5 3 9 4 6], 23.82107
OUT: [8 4 3 6 5 3], 24.12605
IN : [8 4 3 2 3 6], 24.12392
IN : [8 4 3 6 3 6], 23.83048
IN : [8 4 3 6 5 4], 23.16702
OUT: [8 4 3 2 3 6], 24.12392
IN : [8 4 3 9 4 3], 23.58867
IN : [8 4 3 2 0 4], 24.04374
IN : [8 4 3 2 3 2], 24.01985
OUT: [8 9 4 5 3 2], 24.11945
24.11945152282715
IN : [8 9 4 5 3 5], 23.76987
OUT: [8 9 4 3 9 4], 24.11935
IN : [8 9 4 3 6 3], 24.08931
IN : [8 9 4 3 9 5], 23.29622
OUT: [8 2 3 2 3 2], 24.11210
IN : [8 2 3 2 3 5], 23.76253
OUT: [8 3 5 3 2 6], 24.10846
IN : [8 3 5 3 2 4], 23.77819
OUT: [8 2 5 3 2 0], 24.10629
IN : [8 2 0 4 3 6], 24.02040
IN : [8 2 5 4 3 6], 22.93765
IN : [8 2 5 3 5 3], 24.05213
IN : [8 2 5 3 2 3], 24.01226
OUT: [8 3 2 3 2 5], 24.10301
IN : [8 3 2 3 2 6], 24.06858
OUT: [8 5 3 5 3 9], 24.09905
IN : [8 5 3 5 3 4], 23.33392
OUT: [8 3 2 0 9 4], 24.09902
24.09902000427246
IN : [8 3 2 0 5 3], 23.83274
IN : [8 3 2 0 9 5], 23.27589
OUT: [8 9 4 3 6 3], 24.08931
IN : [8 9 4 3 7 5], 23.07766
IN : [8 9 4 3 6 5], 24.04762
OUT: [8 3 6 3 9 4], 24.08734
IN : [8 3 6 3 6 3], 24.05729
IN : [8 3 6 3 9 5], 23.26420
OUT: [8 3 6 9 4 3], 24.07659
IN : [8 3 6 2 3 6], 23.95876
IN : [8 3 6 9 5 3], 23.86256
IN : [8 3 6 9 4 6], 23.76765
OUT: [8 9 4 6 5 3], 24.07594
24.075937271118164
Out[28]:
[array([8, 5, 3, 2, 0, 4]),
 array([8, 9, 4, 3, 2, 0]),
 array([8, 6, 5, 3, 2, 0]),
 array([8, 9, 5, 3, 2, 0]),
 array([8, 9, 4, 5, 3, 6]),
 array([8, 5, 3, 6, 9, 4]),
 array([8, 3, 2, 0, 4, 6]),
 array([8, 9, 4, 5, 3, 2]),
 array([8, 3, 2, 0, 9, 4]),
 array([8, 9, 4, 6, 5, 3])]

In [37]:
list_viterbi_2001(ps0, L0, M0, w_u, w_p, f_u, f_p)


24.8025703379
24.8025703379
24.8025703379
OUT: [8 3 6 5 3 6], 24.802570
OUT: [8 3 5 3 2 0], 24.778082
OUT: [8 3 2 3 2 0], 24.738209
OUT: [8 3 5 3 5 3], 24.723925
OUT: [8 5 3 5 3 6], 24.705131
OUT: [8 3 6 5 3 2], 24.698501
OUT: [8 3 5 3 2 3], 24.684052
OUT: [8 3 2 3 5 3], 24.684052
OUT: [8 5 3 6 5 3], 24.667385
OUT: [8 5 3 2 3 6], 24.665258
OUT: [8 3 2 3 2 3], 24.644179
OUT: [8 5 3 5 3 2], 24.601061
OUT: [8 3 2 5 3 6], 24.596530
OUT: [8 5 3 2 0 4], 24.585074
OUT: [8 5 3 2 3 2], 24.561188
OUT: [8 3 2 5 3 2], 24.492461
OUT: [8 3 2 0 3 6], 24.490598
OUT: [8 9 4 3 2 0], 24.476782
OUT: [8 5 3 2 5 3], 24.461345
OUT: [8 3 2 0 4 3], 24.448411
OUT: [8 3 6 3 2 0], 24.444766
OUT: [8 9 4 3 5 3], 24.422625
OUT: [8 3 5 3 9 4], 24.420653
OUT: [8 3 6 3 5 3], 24.390609
OUT: [8 3 5 3 6 3], 24.390609
OUT: [8 3 2 0 3 2], 24.386529
OUT: [8 9 4 3 2 3], 24.382752
OUT: [8 3 2 3 9 4], 24.380780
OUT: [8 5 3 6 3 6], 24.371814
OUT: [8 5 3 2 0 3], 24.355413
OUT: [8 3 6 3 2 3], 24.350736
OUT: [8 3 2 3 6 3], 24.350736
OUT: [8 3 6 5 3 5], 24.348923
OUT: [8 3 5 3 6 5], 24.348923
OUT: [8 6 5 3 2 0], 24.332573
OUT: [8 3 2 3 6 5], 24.309050
OUT: [8 6 5 3 5 3], 24.278416
OUT: [8 5 3 6 3 2], 24.267745
OUT: [8 3 9 4 3 6], 24.265195
OUT: [8 9 5 3 2 0], 24.262752
OUT: [8 2 3 5 3 6], 24.256046
OUT: [8 5 3 5 3 5], 24.251484
OUT: [8 6 5 3 2 3], 24.238543
OUT: [8 9 4 5 3 6], 24.223520
OUT: [8 2 3 6 5 3], 24.218300
OUT: [8 2 3 2 3 6], 24.216173
OUT: [8 5 3 6 9 4], 24.213257
OUT: [8 5 3 2 3 5], 24.211611
OUT: [8 9 5 3 5 3], 24.208596
OUT: [8 3 6 5 3 9], 24.196489
OUT: [8 9 5 3 2 3], 24.168723
OUT: [8 4 3 5 3 6], 24.163794
OUT: [8 3 9 4 3 2], 24.161126
OUT: [8 2 3 5 3 2], 24.151977
OUT: [8 3 5 3 2 5], 24.142883
OUT: [8 3 2 5 3 5], 24.142883
OUT: [8 3 2 0 4 6], 24.139469
OUT: [8 2 3 2 0 4], 24.135990
OUT: [8 5 3 9 4 3], 24.130010
OUT: [8 4 3 6 5 3], 24.126048
OUT: [8 4 3 2 3 6], 24.123921
OUT: [8 9 4 5 3 2], 24.119451
OUT: [8 9 4 3 9 4], 24.119353
OUT: [8 2 3 2 3 2], 24.112104
OUT: [8 3 5 3 2 6], 24.108458
OUT: [8 2 5 3 2 0], 24.106290
OUT: [8 3 2 3 2 5], 24.103010
OUT: [8 5 3 5 3 9], 24.099049
OUT: [8 3 2 0 9 4], 24.099021
OUT: [8 9 4 3 6 3], 24.089309
OUT: [8 3 6 3 9 4], 24.087336
OUT: [8 3 6 9 4 3], 24.076594
OUT: [8 9 4 6 5 3], 24.075937
Out[37]:
[array([8, 5, 3, 2, 0, 4]),
 array([8, 9, 4, 3, 2, 0]),
 array([8, 6, 5, 3, 2, 0]),
 array([8, 9, 5, 3, 2, 0]),
 array([8, 9, 4, 5, 3, 6]),
 array([8, 5, 3, 6, 9, 4]),
 array([8, 3, 2, 0, 4, 6]),
 array([8, 9, 4, 5, 3, 2]),
 array([8, 3, 2, 0, 9, 4]),
 array([8, 9, 4, 6, 5, 3])]

In [39]:
list_viterbi_1994(ps0, L0, M0, w_u, w_p, f_u, f_p)


24.8025703379
24.8025703379
24.8025703379
[8 3 6 5 3 6]
-----------------------
[8 4 6 5 3 6]
-----------------------
23.8171062651
23.8171062651
-----------------------
OUT: [8 3 6 5 3 6], 24.802570
OUT: [8 3 5 3 2 0], 24.778082
OUT: [8 3 2 3 2 0], 24.738209
OUT: [8 3 5 3 5 3], 24.723925
OUT: [8 5 3 5 3 6], 24.705131
OUT: [8 3 6 5 3 2], 24.698501
OUT: [8 3 5 3 2 3], 24.684052
OUT: [8 3 2 3 5 3], 24.684052
OUT: [8 5 3 6 5 3], 24.667385
OUT: [8 5 3 2 3 6], 24.665258
OUT: [8 3 2 3 2 3], 24.644179
OUT: [8 5 3 5 3 2], 24.601061
OUT: [8 3 2 5 3 6], 24.596530
OUT: [8 5 3 2 0 4], 24.585074
OUT: [8 5 3 2 3 2], 24.561188
OUT: [8 3 2 5 3 2], 24.492461
OUT: [8 3 2 0 3 6], 24.490598
OUT: [8 9 4 3 2 0], 24.476782
OUT: [8 5 3 2 5 3], 24.461345
OUT: [8 3 2 0 4 3], 24.448411
OUT: [8 3 6 3 2 0], 24.444766
OUT: [8 9 4 3 5 3], 24.422625
OUT: [8 3 5 3 9 4], 24.420653
OUT: [8 3 6 3 5 3], 24.390609
OUT: [8 3 5 3 6 3], 24.390609
OUT: [8 3 2 0 3 2], 24.386529
OUT: [8 9 4 3 2 3], 24.382752
OUT: [8 3 2 3 9 4], 24.380780
OUT: [8 5 3 6 3 6], 24.371814
OUT: [8 5 3 2 0 3], 24.355413
OUT: [8 3 6 3 2 3], 24.350736
OUT: [8 3 2 3 6 3], 24.350736
OUT: [8 3 6 5 3 5], 24.348923
OUT: [8 3 5 3 6 5], 24.348923
OUT: [8 6 5 3 2 0], 24.332573
OUT: [8 3 2 3 6 5], 24.309050
OUT: [8 6 5 3 5 3], 24.278416
OUT: [8 5 3 6 3 2], 24.267745
OUT: [8 3 9 4 3 6], 24.265195
OUT: [8 9 5 3 2 0], 24.262752
OUT: [8 2 3 5 3 6], 24.256046
OUT: [8 5 3 5 3 5], 24.251484
OUT: [8 6 5 3 2 3], 24.238543
OUT: [8 9 4 5 3 6], 24.223520
OUT: [8 2 3 6 5 3], 24.218300
OUT: [8 2 3 2 3 6], 24.216173
OUT: [8 5 3 6 9 4], 24.213257
OUT: [8 5 3 2 3 5], 24.211611
OUT: [8 9 5 3 5 3], 24.208596
OUT: [8 3 6 5 3 9], 24.196489
OUT: [8 9 5 3 2 3], 24.168723
OUT: [8 4 3 5 3 6], 24.163794
OUT: [8 3 9 4 3 2], 24.161126
OUT: [8 2 3 5 3 2], 24.151977
OUT: [8 3 5 3 2 5], 24.142883
OUT: [8 3 2 5 3 5], 24.142883
OUT: [8 3 2 0 4 6], 24.139469
OUT: [8 2 3 2 0 4], 24.135990
OUT: [8 5 3 9 4 3], 24.130010
OUT: [8 4 3 6 5 3], 24.126048
OUT: [8 4 3 2 3 6], 24.123921
OUT: [8 9 4 5 3 2], 24.119451
OUT: [8 9 4 3 9 4], 24.119353
OUT: [8 2 3 2 3 2], 24.112104
OUT: [8 3 5 3 2 6], 24.108458
OUT: [8 2 5 3 2 0], 24.106290
OUT: [8 3 2 3 2 5], 24.103010
OUT: [8 5 3 5 3 9], 24.099049
OUT: [8 3 2 0 9 4], 24.099021
OUT: [8 9 4 3 6 3], 24.089309
OUT: [8 3 6 3 9 4], 24.087336
OUT: [8 3 6 9 4 3], 24.076594
OUT: [8 9 4 6 5 3], 24.075937
Out[39]:
[array([8, 5, 3, 2, 0, 4]),
 array([8, 9, 4, 3, 2, 0]),
 array([8, 6, 5, 3, 2, 0]),
 array([8, 9, 5, 3, 2, 0]),
 array([8, 9, 4, 5, 3, 6]),
 array([8, 5, 3, 6, 9, 4]),
 array([8, 3, 2, 0, 4, 6]),
 array([8, 9, 4, 5, 3, 2]),
 array([8, 3, 2, 0, 9, 4]),
 array([8, 9, 4, 6, 5, 3])]

In [ ]:
do_inference_list_viterbi(ps0, L0, M0, w_u, w_p, f_u, f_p, y_true=y_true0, y_true_list=y_true_list0) # allow sub-tours

In [ ]:
do_inference_brute_force(ps0, L0, M0, w_u, w_p, f_u, f_p, debug=True)

In [ ]:
do_inference_list_viterbi(ps0, L0, M0, w_u, w_p, f_u, f_p)

In [ ]:
a = [14,  0, 18, 27, 25]
#a = [14,  0, 27, 25, 18]
#a = [14,  0, 18, 27,  0]
priority = 0

In [ ]:
for t in range(1, L0): 
    ss = a[t-1]
    tt = a[t]
    priority += np.dot(w_p[ss, tt], f_p[ss, tt]) + np.dot(w_u[tt], f_u[tt])

In [ ]:
priority

In [ ]:
do_inference_brute_force(ps0, L0, M0, w_u, w_p, f_u, f_p, y_true=y_true0, y_true_list=y_true_list0)

In [ ]:
do_inference_list_viterbi(ps0, L0, M0, w_u, w_p, f_u, f_p, y_true=y_true0, y_true_list=y_true_list0) # allow sub-tours