In [49]:
import os.path
import numpy as np
import pandas as pd

# make sure that read_csv is available
assert hasattr(pd, 'read_csv')

def load_text_pandas(filename, dtype, skiprows=0, delimiter=' '):
    """Load a delimited text file into a NumPy array of the given dtype, via pandas."""
    with open(filename, 'r') as f:
        # skip the header lines manually so that pandas only sees the data
        for _ in xrange(skiprows):
            f.readline()
        x = pd.read_csv(f, header=None, sep=delimiter).values.astype(dtype).squeeze()
    return x

def get_actual_filename(filename, extension, fileindex=1):
    """Find the most plausible existing file corresponding to the requested
    (approximate) filename, with the required extension and file index."""
    dirname = os.path.dirname(filename)
    prefix = os.path.basename(filename)
    files = os.listdir(dirname)
    if fileindex is None:
        suffix = '.{0:s}'.format(extension)
    else:
        suffix = '.{0:s}.{1:d}'.format(extension, fileindex)
    filtered = []
    # shorten the requested prefix until at least one existing file starts
    # with it and ends with the requested suffix
    while prefix and not filtered:
        filtered = [f for f in files
                    if f.startswith(prefix) and f.endswith(suffix)]
        prefix = prefix[:-1]
    # among the candidates, return the one with the shortest name
    filtered = sorted(filtered, key=len)
    return os.path.join(dirname, filtered[0])
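
A hypothetical usage sketch (file names invented for illustration, not taken from this session):

# with ['trace_32Chan.clu.1', 'trace_32Chan_b.clu.1'] present in 'data/',
#     get_actual_filename('data/trace_32Chan_v2', 'clu', 1)
# shortens the prefix until 'trace_32Chan' matches both files, then returns
# the shortest candidate: 'data/trace_32Chan.clu.1'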

In [50]:
filename = "../_test/data/test"
filename = r"D:\Spike sorting\spike_sorting_janelia\32Channels\trace_32Chan_linear32\MKK_CDfixedderived\kenneth_1\trace_32Chan"
filename = r"D:\Spike sorting\matrix\n6mab061109quarter1_n6mab031109_MKKdistfloat_197_regular2p3_1_subset_129929"
filename = r"D:\Spike sorting\matrix\2\n6mab061109quarter1_n6mab031109_MKKdistfloat_197_regular2p3_1_subset_129929"
fileindex = 1
fetdim = 3
nchannels = 32

In [51]:
path = get_actual_filename(filename, 'clu', fileindex)
# clusters = load_text(path, np.int32)
clusters = load_text_pandas(path, np.int32)
nspikes = len(clusters) - 1

# nclusters = clusters[0]
clusters = clusters[1:]

In [52]:
# FEATURES
# -------------------------------------------------
path = get_actual_filename(filename, 'fet', fileindex)
features = load_text_pandas(path, np.int32, skiprows=1)
features = np.array(features, dtype=np.float32)

# one row per spike: fetdim features per channel, plus the extra feature(s),
# the last of which is the spike time
features = features.reshape((-1, fetdim * nchannels + 1))
ndims = fetdim * nchannels

# get the spike times from the last column
spiketimes = features[:,-1].copy()

# number of extra features beyond the nchannels * fetdim regular ones
nextrafet = features.shape[1] - nchannels * fetdim

# remove the extra features (including the spike time column)
features = features[:,:-nextrafet]

# normalize the features to [-1, 1]
m = features.min()
M = features.max()
# force the range to be symmetric around 0
vx = max(np.abs(m), np.abs(M))
m, M = -vx, vx
features = -1 + 2 * (features - m) / (M - m)

In [53]:
# load the float masks and discard the extra feature columns
path = get_actual_filename(filename, 'fmask', fileindex)
masks = load_text_pandas(path, np.float32, skiprows=1)
masks = masks[:,:-nextrafet]

In [54]:
from collections import Counter
# indices of the spikes belonging to each cluster, in increasing cluster order
c = Counter(clusters)
spikes_in_clusters = [np.nonzero(clusters == clu)[0] for clu in sorted(c)]
nClusters = len(spikes_in_clusters)

In [55]:
Fet1, Fet2 = features, features

nPoints = Fet1.shape[0]   # number of spikes
nDims = Fet1.shape[1]     # number of feature dimensions
nClusters = len(spikes_in_clusters)

In [56]:
# Precompute, for each feature, the mean and variance of its values on the
# spikes where that feature is fully masked (mask == 0): the "noise"
# distribution used to replace masked values.
masked = np.zeros_like(masks)
masked[masks == 0] = 1          # 1 where the feature is fully masked
nmasked = np.sum(masked, axis=0)
nu = np.sum(Fet2 * masked, axis=0) / nmasked
nu = nu.reshape((1, -1))
sigma2 = np.sum(((Fet2 - nu) * masked) ** 2, axis=0) / nmasked
sigma2 = sigma2.reshape((1, -1))
# expected features and their variance under the mask model
y = Fet1 * masks + (1 - masks) * nu
z = masks * Fet1**2 + (1 - masks) * (nu ** 2 + sigma2)
eta = z - y ** 2
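
In symbols, a sketch of what the cell above computes, per spike $n$ and feature $j$, with mask $m_{nj}$, feature value $x_{nj}$, and noise statistics $\nu_j$, $\sigma_j^2$ estimated from the fully masked spikes:

$$y_{nj} = m_{nj}\,x_{nj} + (1 - m_{nj})\,\nu_j,\qquad
z_{nj} = m_{nj}\,x_{nj}^2 + (1 - m_{nj})\,(\nu_j^2 + \sigma_j^2),\qquad
\eta_{nj} = z_{nj} - y_{nj}^2.$$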

In [57]:
LogP = np.zeros((nPoints, nClusters))
stats = {}

for c in xrange(nClusters):
    MyPoints = spikes_in_clusters[c]
    # use the expected features of the spikes in this cluster
    MyFet2 = np.take(y, MyPoints, axis=0)
    if len(MyPoints) > nDims:
        LogProp = np.log(len(MyPoints) / float(nPoints))  # log proportion of spikes in cluster c
        Mean = np.mean(MyFet2, axis=0).reshape((1, -1))
        CovMat = np.cov(MyFet2, rowvar=0)  # covariance matrix of cluster c
        
        # HACK: regularize the covariance matrix to avoid instability issues
        CovMat += np.diag(1e-3 * np.ones(nDims))
        
        # add to the diagonal the correction term coming from the masked
        # (replaced) features of the spikes in this cluster
        etac = np.take(eta, MyPoints, axis=0)
        d = np.sum(etac, axis=0) / nmasked
        CovMat += np.diag(d)
        CovMatinv = np.linalg.inv(CovMat)
        LogDet = np.log(np.linalg.det(CovMat))
        
        stats[c] = (Mean, CovMat, CovMatinv, LogDet, len(MyPoints))

clusters = sorted(stats.keys())
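
In other words, for every cluster $c$ with more spikes than dimensions, the statistics stored above are (a sketch in the code's notation, with $S_c$ the spikes of cluster $c$ and $n^{\mathrm{masked}}_j$ the number of spikes where feature $j$ is fully masked):

$$\hat\mu_c = \frac{1}{|S_c|}\sum_{n\in S_c} y_n,\qquad
\hat C_c = \mathrm{Cov}\big(\{y_n\}_{n\in S_c}\big) + 10^{-3}\,I
+ \mathrm{diag}\!\left(\frac{1}{n^{\mathrm{masked}}_j}\sum_{n\in S_c}\eta_{nj}\right)_{\!j}.$$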

In [58]:
# plot the sorted eigenvalue spectrum (log scale) of each cluster's covariance matrix
eigs = []
for ci in clusters:
    Mean, CovMat, CovMatinv, LogDet, npoints = stats[ci]
    eigs.append(sorted(eig(CovMat)[0])[::-1])
eigs = array(eigs)
figure(figsize=(12,6))
imshow(log(eigs), interpolation='none')
colorbar()


Out[58]:
<matplotlib.colorbar.Colorbar instance at 0x000000000E3A6B88>

In [59]:
#clusters = sorted(stats.keys())
#matrix_original = zeros((nClusters, nClusters))

LogP = np.zeros((nPoints, nClusters))
for c in xrange(nClusters):
    MyPoints = spikes_in_clusters[c]
    if len(MyPoints) > nDims:
        LogProp = np.log(len(MyPoints) / float(nPoints))  # log proportion of spikes in cluster c
        
        Mean, CovMat, CovMatinv, LogDet, npoints = stats[c]
        
        # distance of each point (expected features) from the cluster mean
        dx = y - Mean
        # solve CovMat * y2 = dx instead of multiplying by the inverse explicitly
        y2 = np.linalg.solve(CovMat.T, dx.T).T
        # diagonal correction term accounting for the variance of the masked features
        correction = np.sum(eta * np.diag(CovMatinv).reshape((1, -1)), axis=1)
        # -log of the joint probability that the point lies in cluster c and has the given coordinates
        LogP[:,c] = (np.sum(y2*dx, axis=1)/2. + correction/2.) + LogDet/2. - LogProp + np.log(2*np.pi)*nDims/2.
    else:
        LogP[:,c] = np.inf

JointProb = np.exp(-LogP)

# if a point has all joint probabilities equal to zero, give it a tiny
# probability of belonging to the first cluster
JointProb[np.sum(JointProb, axis=1) == 0, 0] = 1e-9

# posterior probability that each point belongs to each cluster, given its coordinates
P = JointProb / np.sum(JointProb, axis=1).reshape((-1, 1))

matrix_original = np.zeros((nClusters, nClusters))
for c in xrange(nClusters):
    MyPoints = spikes_in_clusters[c]
    matrix_original[c,:] = np.mean(P[MyPoints, :], axis=0)
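
In equation form, the similarity matrix computed above is the mean posterior probability, over the spikes currently assigned to cluster $c$, of belonging to cluster $j$:

$$M_{cj} = \frac{1}{|S_c|}\sum_{n\in S_c} P(j \mid x_n),\qquad
P(j \mid x_n) = \frac{\exp(-\mathrm{LogP}_{nj})}{\sum_{k}\exp(-\mathrm{LogP}_{nk})}.$$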

In [60]:
# discard the diagonal (self-similarity)
matrix_original[range(nClusters), range(nClusters)] = 0

In [61]:
matrix_KL = zeros((nClusters, nClusters))

for ci in clusters:
    for cj in clusters:
        mui, Ci, Ciinv, logdeti, npointsi = stats[ci]
        muj, Cj, Cjinv, logdetj, npointsj = stats[cj]
        dmu = (muj - mui).reshape((-1, 1))
        
        # KL divergence D_KL( N(mui, Ci) || N(muj, Cj) ) between the two cluster Gaussians
        dkl = .5 * (trace(dot(Cjinv, Ci)) + dot(dot(dmu.T, Cjinv), dmu) - logdeti + logdetj - ndims)
        
        matrix_KL[ci, cj] = dkl
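
For reference, this is the closed-form KL divergence between two $d$-dimensional Gaussians (here $d$ = ndims), which is what the line above evaluates:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu_i, C_i)\,\|\,\mathcal{N}(\mu_j, C_j)\big)
= \tfrac12\Big(\mathrm{tr}(C_j^{-1} C_i)
+ (\mu_j-\mu_i)^\top C_j^{-1}(\mu_j-\mu_i)
- d + \ln\det C_j - \ln\det C_i\Big).$$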

In [62]:
matrix_KL = -matrix_KL
matrix_KL[matrix_KL==0] = matrix_KL[matrix_KL!=0].min()

In [63]:
matrix_mean = zeros((nClusters, nClusters))

for ci in clusters:
    for cj in clusters:
        mui, Ci, Ciinv, logdeti, npointsi = stats[ci]
        muj, Cj, Cjinv, logdetj, npointsj = stats[cj]
        dmu = (muj - mui).reshape((-1, 1))
        
        # squared Euclidean distance between the two cluster means
        matrix_mean[ci, cj] = sum(dmu ** 2)

In [64]:
matrix_mean2 = -matrix_mean
matrix_mean2[matrix_mean2==0] = matrix_mean2[matrix_mean2!=0].min()
matrix_mean2.min(), matrix_mean2.max()


Out[64]:
(-0.78876578807830811, -0.0015229303389787674)

In [99]:
matrix_product = zeros((nClusters, nClusters))

for ci in clusters:
    mui, Ci, Ciinv, logdeti, npointsi = stats[ci]
    for cj in clusters:
        muj, Cj, Cjinv, logdetj, npointsj = stats[cj]
        dmu = (muj - mui).reshape((-1, 1))
        
        # log of the overlap integral (scalar product) of the two Gaussian densities
        p = log(2*pi)*(-nDims/2.) + (-.5*log(det(Ci+Cj))) + (-.5)*dot(dot(dmu.T, inv(Ci+Cj)), dmu)
        alpha = float(npointsi) / len(spiketimes)
        matrix_product[ci, cj] = p  # + log(alpha)
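
Here $p$ is the log of the scalar product (overlap integral) of the two Gaussian densities, which has the closed form used above, with $\Delta\mu = \mu_j - \mu_i$ and $d$ = nDims:

$$\int \mathcal{N}(x;\mu_i,C_i)\,\mathcal{N}(x;\mu_j,C_j)\,dx
= \mathcal{N}(\Delta\mu;\,0,\,C_i + C_j)
= (2\pi)^{-d/2}\,\det(C_i+C_j)^{-1/2}
\exp\!\Big(-\tfrac12\,\Delta\mu^\top (C_i+C_j)^{-1}\Delta\mu\Big).$$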

In [104]:
matrix_product[range(len(clusters)), range(len(clusters))] = 0
matrix_product[matrix_product==0] = matrix_product[matrix_product!=0].min()

In [105]:
figure(figsize=(14,8))
subplot(131)
imshow(matrix_original.T, interpolation='none')
title("Klusters")
subplot(132)
imshow(exp(matrix_KL.T), interpolation='none')
title("KL")
subplot(133)
imshow(exp(matrix_product.T), interpolation='none')
title("Scalar product")


Out[105]:
<matplotlib.text.Text at 0xd70ea20>

In [ ]:
import math

def draw_covariance(C, **kwargs):
    """Draw the ellipse whose axes are the principal axes of the covariance matrix C."""
    U, s, Vh = svd(C)
    # orientation of the first principal axis, in degrees (Ellipse expects degrees)
    orient = math.degrees(math.atan2(U[1,0], U[0,0]))
    w = math.sqrt(s[0])
    h = math.sqrt(s[1])
    
    e = Ellipse(xy=(0,0), width=w, height=h, angle=orient, **kwargs)
    ax.add_artist(e)   # uses the global `ax` created before the function is called
    
    xlim(-w, w)
    ylim(-h, h)

In [ ]:
# sanity check of the KL formula and the ellipse drawing on a simple 2-D example
Ci = array([[2., 1.], [1., 2.]])
Cj = array([[2., 1.], [1., 2.]])

dmu = array([[1., 2.]]).T
Cjinv = inv(Cj)
logdeti = log(det(Ci))
logdetj = log(det(Cj))
ndims = 2

dkl = .5 * (trace(dot(Cjinv, Ci)) + dot(dot(dmu.T, Cjinv), dmu) - logdeti + logdetj - ndims)
print dkl

from matplotlib.patches import Ellipse

fig = figure()
ax = fig.add_subplot(111, aspect='equal')
draw_covariance(Ci)
draw_covariance(Cj)