In [ ]:
from skbio.stats.distance import DissimilarityMatrix
import numpy as np
from numpy.linalg import eigh
from numpy import diag, allclose

In [ ]:
class LabelledMatrix(object):
    """Minimal container pairing a matrix with its row/column labels.

    Passed as the target class to skbio's lsmat parser in ``readlsm``
    (presumably instantiated there as ``cls(data, ids)`` - TODO confirm
    against the skbio version in use). It performs no validation of its
    inputs.
    """

    def __init__(self, data, ids):
        # Keep both values exactly as given; nothing is copied or checked.
        self.data, self.ids = data, ids

def readlsm(filename, delimiter='\t'):
    """Read a labelled square matrix (lsmat format) from ``filename``.

    Parameters
    ----------
    filename : str or path-like
        Path to the lsmat file (e.g. the kWIP ``.kern`` / ``.dist`` outputs
        loaded in this notebook).
    delimiter : str, optional
        Field separator used in the file. Defaults to a tab, the value that
        was previously hard-coded.

    Returns
    -------
    LabelledMatrix
        Object exposing ``.data`` (the parsed matrix) and ``.ids`` (labels).
    """
    # NOTE(review): _lsmat_to_matrix is a private skbio helper and may change
    # between skbio releases. It is presumably used instead of
    # DissimilarityMatrix to avoid skbio's validation of the loaded values
    # (a kernel matrix is not a dissimilarity) - TODO confirm.
    from skbio.io.format.lsmat import _lsmat_to_matrix
    with open(filename) as fh:
        return _lsmat_to_matrix(LabelledMatrix, fh, delimiter)

In [ ]:
# Load kWIP's kernel output (.kern) and its distance output (.dist) for the
# same experiment; km is read first, then dm.
km, dm = (readlsm(path) for path in ("kwip/field-expt_wip.kern",
                                     "kwip/field-expt_wip.dist"))

In [ ]:
# Check whether the raw kWIP kernel is positive definite (all eigenvalues > 0).
v, w = eigh(km.data)
if all(v > 0):
    print("PD matrix")
else:
    # Report the failure explicitly instead of printing nothing, and show how
    # far from positive definite the kernel is.
    print("Not PD: smallest eigenvalue is %g" % v.min())
print(v)

In [ ]:
def normalise_01(kmat):
    """Linearly rescale the values of kmat onto the interval [0, 1].

    The smallest entry maps to 0 and the largest entry maps to 1.

    Parameters
    ----------
    kmat : array_like
        Matrix (or any array) of numeric values.

    Returns
    -------
    Array of the same shape as ``kmat`` with values in [0, 1].

    Raises
    ------
    ValueError
        If every entry of ``kmat`` is equal. The value range is then zero and
        the rescaling is undefined; the old code silently produced NaNs.
    """
    smallest = np.min(kmat)
    largest = np.max(kmat)
    value_range = largest - smallest
    if value_range == 0:
        raise ValueError("cannot normalise a matrix whose values are all equal")
    return (kmat - smallest) / value_range

In [ ]:
def center(K):
    """Center the kernel matrix, such that the mean (in feature space) is zero.

    Applies double-centering, Kc = K - 1_n K - K 1_n + 1_n K 1_n, where 1_n
    is the n x n matrix with every entry 1/n. Element-wise this is

        Kc[i, j] = K[i, j] - rowmean[i] - colmean[j] + grandmean

    so every row and every column of the result averages to zero. The
    previous implementation subtracted the column mean on both sides, which
    only matches this for symmetric K.

    Parameters
    ----------
    K : array_like
        Kernel (Gram) matrix, normally square and symmetric. ndarray and
        np.matrix inputs are both accepted.

    Returns
    -------
    numpy.matrix
        The centered kernel. Kept as ``np.matrix`` for compatibility with the
        previous implementation's return type.
    """
    K = np.asarray(K, dtype=float)
    row_means = K.mean(axis=1, keepdims=True)  # column vector: mean of each row
    col_means = K.mean(axis=0, keepdims=True)  # row vector: mean of each column
    centered = K - row_means - col_means + K.mean()
    return np.asmatrix(centered)

In [ ]:
def normalise_unit_diag(kmat):
    """Rescale kmat so that every diagonal entry becomes one.

    Each entry is divided by the geometric mean of its two diagonal entries,
    K[i, j] / sqrt(K[i, i] * K[j, j]). Because the scaling matrix is built
    as an ``np.matrix``, the result is an ``np.matrix`` too.
    """
    diag_col = np.matrix(np.diag(kmat)).T   # (n, 1) column of K[i, i]
    outer = diag_col * diag_col.T           # (n, n): K[i, i] * K[j, j]
    return np.divide(kmat, np.sqrt(outer))

In [ ]:
def kernel2dist(kmat):
    """Convert the kernel matrix into the corresponding distance matrix.

    Uses the kernel-induced distance

        d(i, j) = sqrt(K[i, i] + K[j, j] - 2 * K[i, j])

    Squared distances that are not strictly positive (zero, slightly negative
    from floating-point round-off, or NaN) give a distance of 0, as in the
    previous element-wise loop.

    Parameters
    ----------
    kmat : array_like, shape (n, n)
        Square kernel matrix (ndarray or np.matrix).

    Returns
    -------
    numpy.ndarray, shape (n, n), dtype float64
        Pairwise distances.

    Raises
    ------
    ValueError
        If ``kmat`` is not a square 2-D matrix.
    """
    K = np.asarray(kmat, dtype=float)
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError("kernel matrix must be square, got shape %r" % (K.shape,))
    self_sim = np.diagonal(K)
    # Broadcast K[i, i] down rows and K[j, j] across columns.
    sqr_dist = self_sim[:, np.newaxis] + self_sim[np.newaxis, :] - 2.0 * K
    # Clamp non-positive (and NaN) squared distances to 0 before the sqrt,
    # matching the original `if sqr_dist > 0.0` guard.
    return np.sqrt(np.where(sqr_dist > 0.0, sqr_dist, 0.0))

In [ ]:
# Rescale the kWIP kernel to a unit diagonal, then inspect its eigenvalues.
nk = normalise_unit_diag(km.data)
# UPLO='L' is eigh's default: only the lower triangle of nk is read.
v, w = eigh(nk, UPLO='L')
print(v)

In [ ]:
# Distances implied by the unit-diagonal kernel.
d = kernel2dist(kmat=nk)

In [ ]:
# Compare our distances with the .dist matrix kWIP wrote: 0.1% relative
# tolerance and no absolute slack (argument order matters for np.allclose).
print("Normalisation is the same between kWIP & old code"
      if np.allclose(d, dm.data, rtol=0.001, atol=0.0)
      else "BAD NORM")

In [ ]:
# (min, max, mean) of the distances; min is 0 because each sample's
# distance to itself is 0.
np.min(d), np.max(d), np.mean(d)

In [ ]:
# Eigendecomposition of the (symmetric) distance matrix.
v, w = np.linalg.eigh(d)

In [ ]:
# Sum of the eigenvalues equals the trace of d, which is 0 because d has a
# zero diagonal, so this should be ~0 up to round-off.
np.sum(v)

In [ ]:
# NOTE(review): this cell used to be `v, oldv`, but `oldv` is never assigned in
# the visible notebook (likely a leftover from a deleted cell), so it raised
# NameError on Restart & Run All. Only `v` is shown until a reference spectrum
# is recomputed explicitly in an earlier cell.
v

In [ ]: