In [1]:
import os

# Directory holding the Klusters file set. Set the KLUSTERS_DATA_DIR environment
# variable to point elsewhere instead of editing this hardcoded local path; the
# default reproduces the original absolute path.
DATA_DIR = os.environ.get("KLUSTERS_DATA_DIR", r"D:\Spike sorting\matrix_bug")
filename = os.path.join(
    DATA_DIR,
    "n6mab041109_n6mab031109_MKKdistfloat_25_regular1p7_1_subset_129989.clu.1",
)

Loading files


In [2]:
from loader import KlustersLoader
from selection import get_spikes_in_clusters, select, select_numpy

In [3]:
# Create a KlustersLoader for `filename`; its get_* methods are called in the
# next cell to pull out features, waveforms, masks and cluster labels.
# NOTE(review): presumably the loader locates the sibling Klusters files from
# this .clu path — confirm in loader.py.
l = KlustersLoader(filename)

In [4]:
# Fetch everything the analysis needs from the loader in one unpacking
# (the getters are still called in the same order). `full=True` is passed
# through unchanged to get_masks; its meaning is defined in loader.py.
features, waveforms, masks, clusters = (
    l.get_features(),
    l.get_waveforms(),
    l.get_masks(full=True),
    l.get_clusters(),
)

In [5]:
from collections import Counter

# Cluster label -> number of spikes carrying that label.
c = Counter(clusters)
# One array of spike indices per cluster, ordered by ascending cluster label.
spikes_in_clusters = [np.flatnonzero(clusters == label) for label in sorted(c)]

In [6]:
from correlations import compute_correlations, compute_statistics
from statstools import matrix_of_pairs

In [7]:
# Per-cluster statistics. The feature array is passed as both of the first two
# arguments, followed by the per-cluster spike index lists and the masks.
# report_pair_stats unpacks each entry as (mu, C, Cinv, logdet, n).
# NOTE(review): this uses the compute_statistics imported in In [6]; the
# `%run correlations` in the next cell presumably rebinds it, so edits to
# correlations.py are not reflected here until this cell is re-run — confirm intent.
# NOTE(review): entries are later looked up as stats[c0] with a cluster label;
# if compute_statistics keys its results by position in spikes_in_clusters
# rather than by label, that lookup is only correct when the labels are exactly
# 0..K-1 — verify.
stats = compute_statistics(features, features, spikes_in_clusters, masks)

In [8]:
# Re-execute correlations.py in the interactive namespace (picks up edits to the
# module without restarting the kernel), then compute the pairwise cluster
# correlations and arrange them as a matrix with statstools.matrix_of_pairs.
# NOTE(review): the result is bound to `correlations`, the module's own name;
# harmless for `%run` (which takes a file name) but easy to misread.
%run correlations
correlations = compute_correlations(features, clusters, masks)
matrix = matrix_of_pairs(correlations)

Report functions


In [9]:
def _select_cluster(clu):
    """Return ``(waves, feats, channel_masks)`` for the spikes of cluster ``clu``.

    Reads the notebook globals ``clusters``, ``waveforms``, ``features`` and
    ``masks``. ``channel_masks`` drops the last mask column and keeps every
    third remaining one — presumably one mask value per channel, matching the
    ``3 * i`` feature indexing in ``report_pair`` (TODO confirm against the
    loader's mask layout).
    """
    spikes = get_spikes_in_clusters([clu], clusters)
    waves = select(waveforms, spikes)
    feats = select(features, spikes)
    channel_masks = select(masks, spikes)[:, :-1][:, ::3]
    return waves, feats, channel_masks


def report_pair(c0, c1, nfetplots=8, nwaveforms=100, ncols=8):
    """Plot masks, feature projections and waveforms for two clusters.

    Cluster ``c0`` is drawn in blue and ``c1`` in red, in three figures:

    1. "Masks": mean mask value per channel for each cluster.
    2. "Channels": a grid of scatter plots of the first feature of every pair
       among the channels with the largest summed mean mask.
    3. "Waveforms": a random subset of each cluster's waveforms, per channel.

    Parameters
    ----------
    c0, c1 : int
        Cluster labels, as stored in the global ``clusters``.
    nfetplots : int
        Maximum number of channels in the feature grid; capped at the number
        of channels so the grid and its titles stay aligned.
    nwaveforms : int
        Maximum number of waveforms drawn per cluster; smaller clusters are
        drawn in full.
    ncols : int
        Number of subplot columns in the waveform figure.
    """
    from itertools import product

    waves0, features0, masks0 = _select_cluster(c0)
    waves1, features1, masks1 = _select_cluster(c1)
    meanmasks0 = masks0.mean(axis=0)
    meanmasks1 = masks1.mean(axis=0)

    # MASKS
    fig = figure()
    fig.suptitle("Masks", fontsize=20)
    plot(meanmasks0, 'b', label=str(c0))
    plot(meanmasks1, 'r', label=str(c1))
    grid()
    legend(loc=2)

    # FEATURES: the most relevant channels are those with the largest summed
    # mean mask. Never request more channels than exist, otherwise the
    # nfetplots x nfetplots grid and the first-row titles would not line up.
    allmasks = meanmasks0 + meanmasks1
    nfetplots = min(nfetplots, len(allmasks))
    best_channels = sort(np.argsort(allmasks)[::-1][:nfetplots])

    fig = figure(figsize=(10, 10))
    fig.suptitle("Channels", fontsize=20)
    for k, (i, j) in enumerate(product(best_channels, best_channels)):
        subplot(nfetplots, nfetplots, k + 1)
        # First feature of channel i (x) against first feature of channel j (y).
        plot(features0[:, 3 * i], features0[:, 3 * j], 'b,', alpha=.25)
        plot(features1[:, 3 * i], features1[:, 3 * j], 'r,', alpha=.25)
        xticks([])
        yticks([])
        if k < nfetplots:
            # First row: label each column with its channel index.
            title("%d" % j)

    # WAVEFORMS: channels are on the last axis of the waveform arrays.
    nchannels = waves0.shape[2]
    nrows = (nchannels + ncols - 1) // ncols
    # Draw one subset per cluster, without replacement, and reuse it on every
    # channel: the same spikes are then shown across channels, small clusters
    # are not drawn with duplicated (darker) traces, and an empty cluster no
    # longer makes the sampler raise.
    ind0 = np.random.permutation(waves0.shape[0])[:nwaveforms]
    ind1 = np.random.permutation(waves1.shape[0])[:nwaveforms]

    fig = figure(figsize=(16, 8))
    fig.suptitle("Waveforms", fontsize=20)
    for i in range(nchannels):
        subplot(nrows, ncols, i + 1)
        plot(waves0[ind0, :, i].T, 'b', alpha=.25)
        plot(waves1[ind1, :, i].T, 'r', alpha=.25)
        title("Channel %d" % i)
        xticks([])
        yticks([])

In [10]:
def report_pair_stats(c0, c1, vmin=None, vmax=None, fpc=3):
    """Plot the mean feature vectors and covariance matrices of two clusters.

    Cluster ``c0`` is drawn in blue and ``c1`` in red in the "Means" figure;
    the "Covs" figure shows the two covariance matrices side by side.

    Parameters
    ----------
    c0, c1 : int
        Keys into the global ``stats``; each entry unpacks to
        ``(mu, C, Cinv, logdet, n)``.
        NOTE(review): if ``compute_statistics`` keys its results by position in
        ``spikes_in_clusters`` rather than by cluster label, these keys only
        match the labels used by ``report_pair`` when the labels are exactly
        0..K-1 — verify.
    vmin, vmax : float or None
        Colour limits applied to both covariance images. ``None`` (default)
        lets ``imshow`` choose them, as before.
    fpc : int
        Features per channel; sets the channel scale of the means x axis
        (3 with 96 features reproduces the previous ``linspace(0., 31., 96)``).
    """
    mu0, C0, _Cinv0, _logdet0, _n0 = stats[c0]
    mu1, C1, _Cinv1, _logdet1, _n1 = stats[c1]

    # Plot mean vectors. The last entry of each mean vector is dropped
    # (presumably a non-feature column — TODO confirm).
    mu0bis = mu0.ravel()[:-1]
    mu1bis = mu1.ravel()[:-1]
    # Derive the channel axis from the vector length instead of hardcoding 96.
    nfeatures = mu0bis.size
    channel_axis = linspace(0., nfeatures / float(fpc) - 1., nfeatures)

    fig = figure()
    fig.suptitle("Means", fontsize=20)
    plot(channel_axis, mu0bis, 'bo', label=str(c0))
    plot(channel_axis, mu1bis, 'ro', label=str(c1))
    grid()
    legend(loc=2)

    # Plot covariance matrices side by side, each with a colour bar so the
    # image values can actually be read.
    fig = figure(figsize=(14, 8))
    fig.suptitle("Covs", fontsize=20)
    for k, (clu, cov) in enumerate([(c0, C0), (c1, C1)]):
        subplot(1, 2, k + 1)
        imshow(cov, interpolation='none', vmin=vmin, vmax=vmax)
        title(str(clu))
        colorbar()

Plotting cluster pairs


In [11]:
# Cluster labels of the pair under inspection.
c0 = 83
c1 = 76

In [12]:
# Masks, feature projections and waveforms for the selected pair.
report_pair(c0=c0, c1=c1);



In [13]:
# Inline version of the report_pair_stats(c0, c1) plots, with both covariance
# images on a fixed colour scale and with colour bars so the two clusters can
# be compared directly.
# NOTE(review): stats is indexed with cluster labels; if compute_statistics keys
# its results by position in spikes_in_clusters, this only matches the labels
# used by report_pair when they are exactly 0..K-1 — verify.
(mu0, C0, Cinv0, logdet0, n0) = stats[c0]
(mu1, C1, Cinv1, logdet1, n1) = stats[c1]

# Colour limits shared by both covariance images.
cov_vmin, cov_vmax = 0, .01

# Plot mean vectors (last entry dropped) against a channel axis derived from
# the vector length, 3 features per channel, instead of a hardcoded 96 features.
mu0bis = mu0.ravel()[:-1]
mu1bis = mu1.ravel()[:-1]
channel_axis = linspace(0., mu0bis.size / 3. - 1., mu0bis.size)
fig = figure()
fig.suptitle("Means", fontsize=20)
plot(channel_axis, mu0bis, 'bo', label=str(c0));
plot(channel_axis, mu1bis, 'ro', label=str(c1));
grid();
legend(loc=2);

# Plot cov matrices.
fig = figure(figsize=(14, 8))
fig.suptitle("Covs", fontsize=20)
subplot(121);
imshow(C0, interpolation='none', vmin=cov_vmin, vmax=cov_vmax);
title(str(c0));
colorbar();
subplot(122);
imshow(C1, interpolation='none', vmin=cov_vmin, vmax=cov_vmax);
title(str(c1));
colorbar();



In [13]: