In [1]:
import os
# Klusters .clu file to inspect. The default is the author's local copy; set the
# KLUSTERS_FILENAME environment variable to analyse another file without editing
# the notebook.
filename = os.environ.get(
    "KLUSTERS_FILENAME",
    r"D:\Spike sorting\matrix_bug\n6mab041109_n6mab031109_MKKdistfloat_25_regular1p7_1_subset_129989.clu.1",
)
In [2]:
from loader import KlustersLoader
from selection import get_spikes_in_clusters, select, select_numpy
In [3]:
# Loader for `filename`; the next cell reads features, waveforms, masks and
# cluster assignments from it.
# NOTE(review): the single-letter name `l` is easy to misread as `1`.
l = KlustersLoader(filename)
In [4]:
# Pull the arrays used throughout the rest of the notebook from the loader.
features = l.get_features()
waveforms = l.get_waveforms()
# full=True: presumably one mask value per feature rather than per channel;
# report_pair later reduces these to one column per channel -- TODO confirm.
masks = l.get_masks(full=True)
# Per-spike cluster assignment (compared element-wise against cluster ids below).
clusters = l.get_clusters()
In [5]:
from collections import Counter

# Number of spikes per cluster id; only the set of ids (its keys) is used here.
c = Counter(clusters)

# Entry k lists the indices of the spikes belonging to the k-th smallest
# cluster id, i.e. the list is ordered like sorted(c).
spikes_in_clusters = list(
    np.nonzero(clusters == clu)[0]
    for clu in sorted(c)
)
In [6]:
from correlations import compute_correlations, compute_statistics
from statstools import matrix_of_pairs
In [7]:
# Per-cluster statistics; later cells unpack each entry as (mu, C, Cinv, logdet, n).
# NOTE(review): `features` is passed for both of the first two arguments --
# confirm this is intended for compute_statistics' signature.
# NOTE(review): spikes_in_clusters is ordered by sorted cluster id, but later
# cells index `stats` by cluster id (stats[83]); these agree only if
# compute_statistics keys its result by cluster id, or ids are contiguous from 0.
stats = compute_statistics(features, features, spikes_in_clusters, masks)
In [8]:
# Re-execute correlations.py in the interactive namespace so that edits to the
# module are picked up; this rebinds the compute_correlations/compute_statistics
# names imported in an earlier cell.
# NOTE(review): `stats` was computed in the previous cell, *before* this re-run,
# so it may come from an older compute_statistics -- re-run it if the module changed.
%run correlations
correlations = compute_correlations(features, clusters, masks)
# NOTE(review): if the pylab namespace is loaded, this shadows its `matrix` name.
matrix = matrix_of_pairs(correlations)
In [9]:
def _select_cluster(cluster, fetdim):
    """Return (waveforms, features, per-channel masks) for the spikes of ``cluster``."""
    spikes = get_spikes_in_clusters([cluster], clusters)
    cluster_waves = select(waveforms, spikes)
    cluster_features = select(features, spikes)
    # Drop the last mask column (not tied to a channel -- presumably the time
    # feature; TODO confirm), then keep one column out of every `fetdim`.
    cluster_masks = select(masks, spikes)[:, :-1][:, ::fetdim]
    return cluster_waves, cluster_features, cluster_masks


def _plot_masks(masks0, masks1, c0, c1):
    """Plot the mean mask of every channel for both clusters (c0 blue, c1 red)."""
    fig = figure()
    fig.suptitle("Masks", fontsize=20)
    plot(masks0.mean(axis=0), 'b', label=str(c0))
    plot(masks1.mean(axis=0), 'r', label=str(c1))
    grid()
    legend(loc=2)


def _best_channels(masks0, masks1, nchannels):
    """Return, in increasing order, the ``nchannels`` channels with the largest summed mean mask."""
    allmasks = masks0.mean(axis=0) + masks1.mean(axis=0)
    return np.sort(np.argsort(allmasks)[::-1][:nchannels])


def _plot_feature_grid(features0, features1, channels, fetdim):
    """Scatter feature column ``fetdim*i`` against ``fetdim*j`` for every pair of ``channels``."""
    from itertools import product

    # Size the grid from the channels actually selected: with fewer channels
    # than requested, a fixed-size grid would wrap rows into the wrong layout.
    n = len(channels)
    fig = figure(figsize=(10, 10))
    fig.suptitle("Channels", fontsize=20)
    for k, (i, j) in enumerate(product(channels, channels)):
        subplot(n, n, k + 1)
        plot(features0[:, fetdim * i], features0[:, fetdim * j], 'b,', alpha=.25)
        plot(features1[:, fetdim * i], features1[:, fetdim * j], 'r,', alpha=.25)
        xticks([])
        yticks([])
        if k < n:
            # Top row: label each column with its channel index.
            title("%d" % j)


def _random_spikes(waves, nwaveforms):
    """Return ``nwaveforms`` random row indices into ``waves`` (with replacement), or None if it is empty."""
    if waves.shape[0] == 0:
        # randint(high=0) would raise; an empty cluster is simply not drawn.
        return None
    return randint(size=nwaveforms, low=0, high=waves.shape[0])


def _plot_waveforms(waves0, waves1, nwaveforms, ncols=8):
    """Overlay random waveforms of both clusters, one subplot per channel (axis 2 of ``waves``)."""
    nchannels = waves0.shape[2]
    nrows = (nchannels + ncols - 1) // ncols
    # Draw each cluster's subset once so every channel panel shows the same spikes.
    ind0 = _random_spikes(waves0, nwaveforms)
    ind1 = _random_spikes(waves1, nwaveforms)
    fig = figure(figsize=(16, 8))
    fig.suptitle("Waveforms", fontsize=20)
    for channel in range(nchannels):
        subplot(nrows, ncols, channel + 1)
        if ind0 is not None:
            plot(waves0[ind0, :, channel].T, 'b', alpha=.25)
        if ind1 is not None:
            plot(waves1[ind1, :, channel].T, 'r', alpha=.25)
        title("Channel %d" % channel)
        xticks([])
        yticks([])


def report_pair(c0, c1, nfetplots=8, nwaveforms=100, fetdim=3):
    """Plot masks, feature projections and waveforms of two clusters, overlaid.

    Cluster ``c0`` is drawn in blue and ``c1`` in red. Three figures are made:
    the mean mask per channel, a scatter grid over the channels with the
    largest combined mean mask, and random waveforms on every channel.

    Uses the notebook globals ``clusters``, ``waveforms``, ``features`` and
    ``masks`` and the pylab plotting namespace.

    Parameters
    ----------
    c0, c1 : cluster ids, as accepted by ``get_spikes_in_clusters``.
    nfetplots : maximum number of channels shown in the feature grid.
    nwaveforms : number of waveforms drawn (with replacement) per cluster.
    fetdim : number of feature columns per channel; used to subsample the
        mask columns and to pick the feature column plotted for a channel.
    """
    waves0, features0, masks0 = _select_cluster(c0, fetdim)
    waves1, features1, masks1 = _select_cluster(c1, fetdim)

    _plot_masks(masks0, masks1, c0, c1)

    channels = _best_channels(masks0, masks1, nfetplots)
    _plot_feature_grid(features0, features1, channels, fetdim)

    _plot_waveforms(waves0, waves1, nwaveforms)
In [10]:
def report_pair_stats(c0, c1, fetdim=3, vmin=None, vmax=None, colorbars=False):
    """Plot the fitted mean vectors and covariance matrices of two clusters.

    Each ``stats`` entry is unpacked as ``(mu, C, Cinv, logdet, n)``; only
    ``mu`` and ``C`` are plotted (c0 in blue, c1 in red).

    NOTE(review): ``stats`` was built from ``spikes_in_clusters``, which is
    ordered by sorted cluster id. If ``compute_statistics`` indexes its result
    by list position, ``stats[c0]`` is the c0-th cluster, not cluster id c0 --
    confirm against ``compute_statistics``.

    Parameters
    ----------
    c0, c1 : keys into ``stats``.
    fetdim : features per channel; the mean features are spread over x
        positions 0 .. nchannels-1 (96 features -> linspace(0, 31, 96)).
    vmin, vmax : colour range of both covariance images (None = autoscale).
    colorbars : if True, add a colour bar next to each covariance image.
    """
    mu0, C0, _, _, _ = stats[c0]
    mu1, C1, _, _, _ = stats[c1]

    # Drop the last entry of each mean (presumably the non-channel/time
    # feature, as for the masks -- TODO confirm).
    mu0bis = mu0.ravel()[:-1]
    mu1bis = mu1.ravel()[:-1]
    nfeatures = len(mu0bis)
    x = linspace(0., float(nfeatures // fetdim - 1), nfeatures)

    fig = figure()
    fig.suptitle("Means", fontsize=20)
    plot(x, mu0bis, 'bo', label=str(c0))
    plot(x, mu1bis, 'ro', label=str(c1))
    grid()
    legend(loc=2)

    # Two side-by-side images: a wide figure avoids the empty space of 14x14.
    fig = figure(figsize=(14, 8))
    fig.suptitle("Covs", fontsize=20)
    for position, cov in ((121, C0), (122, C1)):
        subplot(position)
        imshow(cov, interpolation='none', vmin=vmin, vmax=vmax)
        if colorbars:
            colorbar()
In [11]:
# Cluster ids of the pair under inspection.
c0 = 83
c1 = 76
In [12]:
# Visual comparison of the selected pair: masks, feature grid and waveforms.
report_pair(c0=c0, c1=c1)
In [13]:
# Near-duplicate of report_pair_stats, run inline with a shorter figure, a fixed
# colour range (0 .. 0.01) and colour bars. It leaves mu0, C0, ..., mu1bis bound
# in the notebook namespace for further inspection.
mu0, C0, Cinv0, logdet0, n0 = stats[c0]
mu1, C1, Cinv1, logdet1, n1 = stats[c1]
mu0bis, mu1bis = mu0.ravel()[:-1], mu1.ravel()[:-1]

# Mean vectors: c0 in blue, c1 in red, 96 features spread over x in [0, 31].
fig = figure()
fig.suptitle("Means", fontsize=20)
plot(linspace(0., 31., 96), mu0bis, 'bo',
     linspace(0., 31., 96), mu1bis, 'ro');
grid();

# Covariance matrices side by side on the same colour scale.
fig = figure(figsize=(14, 8))
fig.suptitle("Covs", fontsize=20)
subplot(1, 2, 1);
imshow(C0, vmin=0, vmax=.01, interpolation='none');
colorbar();
subplot(1, 2, 2);
imshow(C1, vmin=0, vmax=.01, interpolation='none');
colorbar();
In [13]: