Spank 2

A lighter-weight version of Spank (smaller arrays and less plotting).

Also trying out some parallel computation.


In [1]:
from IPython import parallel
clients = parallel.Client()
clients.block = True  # use synchronous computations
print clients.ids


[0, 1, 2, 3]


In [13]:
# Test with a trivial example
def mul(a, b):
    return a * b

In [14]:
# clients[:].apply(mul, 5, 6)
view = clients.load_balanced_view()
view.map(mul, [5, 6, 7, 8], [8, 9, 10, 11])


Out[14]:
[40, 54, 70, 88]
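
The same load-balanced view could, in principle, farm out the per-chunk work later in this notebook. A sketch only: `process_chunk` is a hypothetical placeholder (e.g. thresholding one block of samples), and the data itself would still need to be pushed to, or loaded on, each engine.


In [ ]:
# Sketch: spread per-chunk work over the engines with the load-balanced view.
def process_chunk(bounds):
    lo, hi = bounds
    return (lo, hi)   # placeholder: would return spike indices found in samples lo:hi

chunks = [(i, i + 5000) for i in range(0, 20000, 5000)]
results = view.map(process_chunk, chunks, block=True)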

In [19]:
# Imports are local here; to make them available on the parallel engines too, rerun them under %%px.
import numpy as np
import tables as tb
import pandas as pd
import scipy.io as sio
from scipy import linalg
import itertools


from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn import mixture
from sklearn.externals.six.moves import xrange

%matplotlib inline
#%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl

Load a small subset of the ephys data


In [120]:
gut = np.loadtxt('Gutnisky_py.txt',usecols=range(1,20001))
#gut = pd.read_table('Gutnisky_py.txt',index_col=0,usecols=range(0,10001))

In [121]:
gut.shape
#gut?
# gut.describe


Out[121]:
(8, 20000)

Extract spike times


In [122]:
# Find std deviation of the raw waveforms, and use a multiple of this to extract spiketimes
m = np.std(gut)
# Take absolute peak this time, to catch positive spikes
mask = np.absolute(gut)>m*4
indices = np.where(mask)
indices = indices[1]

In [123]:
# Sort and remove spikes from edges, and those too close together
index = np.sort(indices)
# plt.plot(np.diff(index))

#index = np.sort(index)
mask = np.diff(index)>8
idx = np.where(mask)
spks = index[idx]

spks = spks[(spks)>15]
spks.shape


Out[123]:
(142,)

In [139]:
# Non shifted waveforms
waves = np.zeros((len(spks),8*30))
for i in range(len(spks)):
    x = np.reshape(gut[:,spks[i]-15:spks[i]+15],240)
    waves[i,:] = x

Load waveforms, determine the location of the peak, work out that peak's offset from the 15-sample centre and re-align.

TO DO: Interpolate and do this once to see the range of shifts, then define a slightly larger window for waveform extraction, interpolation and shifting (to allow for sub-sample-rate shifts). Non-trivial: how do we then decide when the spike occurred with respect to our sample rate?
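
One possible shape for the interpolation step in the TO DO above. A sketch only: it assumes a raw window a few samples wider than the final 30-sample cut, uses scipy's interp1d, and the x10 upsampling factor is arbitrary.


In [ ]:
# Sketch: sub-sample re-alignment by upsampling a wider window around each coarse peak.
from scipy.interpolate import interp1d

WIDE, FINAL, UP = 40, 30, 10                      # raw window, final window, upsample factor
t = np.arange(WIDE)
t_fine = np.linspace(0, WIDE - 1, WIDE * UP)

def align_subsample(chunk):
    # chunk: (8, WIDE) raw snippet centred on the coarse peak
    fine = interp1d(t, chunk, kind='cubic', axis=1)(t_fine)    # (8, WIDE*UP)
    pk = np.unravel_index(np.argmax(np.abs(fine)), fine.shape)[1]
    start = min(max(pk - (FINAL * UP) // 2, 0), fine.shape[1] - FINAL * UP)
    return fine[:, start:start + FINAL * UP:UP]                # back to (8, FINAL)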


In [141]:
# Shifted indices, for shifted waveforms
shift = np.zeros(len(spks))
index = np.zeros_like(spks)
for i in range(len(spks)):
    x = np.reshape(gut[:,spks[i]-15:spks[i]+15],240)
    pk = np.argmax(np.absolute(x))
    shift[i] = 15-np.mod(pk,30)  # NB sign convention is flipped relative to the corrected version further down (np.mod(pk,30)-15)
    index[i] = spks[i]+shift[i]

In [142]:
plt.hist(shift)


Out[142]:
(array([  3.,   1.,   0.,   0.,  57.,  34.,  37.,   5.,   3.,   2.]),
 array([-14. , -11.1,  -8.2,  -5.3,  -2.4,   0.5,   3.4,   6.3,   9.2,
         12.1,  15. ]),
 <a list of 10 Patch objects>)

In [143]:
wavesS = np.zeros((len(index),8*30))
for i in range(len(spks)):
  #  print i
    x = np.reshape(gut[:,index[i]-15:index[i]+15],240)
    wavesS[i,:] = x

In [144]:
# from sklearn.decomposition import PCA
# Define PCA w/ whitening 
pca = PCA(n_components=2, whiten=True)
# Fit the PCA
X_S = pca.fit(wavesS).transform(wavesS)
# Percentage of variance explained by each component
print('explained variance ratio (first few components): %s'
      % str(pca.explained_variance_ratio_))
#from sklearn.cluster import KMeans
kmeans = KMeans(4, random_state=8)
Y_hat = kmeans.fit(X_S).labels_
plt.scatter(X_S[:,0], X_S[:,1], c=Y_hat,alpha = 0.4);
mu = kmeans.cluster_centers_
plt.scatter(mu[:,0], mu[:,1], s=100, c=np.unique(Y_hat))


explained variance ratio (first few components): [ 0.14706308  0.13592002]
Out[144]:
<matplotlib.collections.PathCollection at 0x10fe09390>

In [152]:
# Tried generating inverse waveforms, but it's not working (not enough signal using PCA)
fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(12, 4))
invMu = pca.inverse_transform(mu)
invMu.shape
# plt.plot(invMu[3,:].reshape(30,8))
plt.plot(invMu.T)


Out[152]:
[<matplotlib.lines.Line2D at 0x10a9c2310>,
 <matplotlib.lines.Line2D at 0x10a9c2590>,
 <matplotlib.lines.Line2D at 0x10a9c27d0>,
 <matplotlib.lines.Line2D at 0x10a9c2990>]

In [427]:
invMu.shape


Out[427]:
(6, 160)

In [434]:
# # Plot mean waveforms of Kmeans clusters
# fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(12, 4))
# for i in xrange(len(invMu)):
#     plt.subplot(1,len(invMu),i+1)
#     plt.plot(invMu[i,:].reshape(8,30).T)

# plt.tight_layout()
colours


Out[434]:
array([[  5.00000000e-01,   0.00000000e+00,   1.00000000e+00,
          1.00000000e+00],
       [  1.70588235e-01,   4.94655843e-01,   9.66718404e-01,
          1.00000000e+00],
       [  1.66666667e-01,   8.66025404e-01,   8.66025404e-01,
          1.00000000e+00],
       [  5.03921569e-01,   9.99981027e-01,   7.04925547e-01,
          1.00000000e+00],
       [  8.33333333e-01,   8.66025404e-01,   5.00000000e-01,
          1.00000000e+00],
       [  1.00000000e+00,   4.94655843e-01,   2.55842778e-01,
          1.00000000e+00],
       [  1.00000000e+00,   1.22464680e-16,   6.12323400e-17,
          1.00000000e+00]])

In [435]:
# Fit kmeans directly to data
MnumS = 4
kmeans = KMeans(MnumS)
mu_wavesS = kmeans.fit(wavesS).cluster_centers_

colours = cm.rainbow(np.linspace(0, 1, 8))

fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(12, 4))
for i in xrange(MnumS):
    plt.subplot(1,MnumS,i+1)
    plt.plot(mu_wavesS[i].reshape(8,30).T,color=colours[:])

plt.tight_layout()


<matplotlib.figure.Figure at 0x1166b1950>

In [156]:
# Fit kmeans directly to data, non-shifted
Mnum = 4
kmeans = KMeans(Mnum)
mu_waves = kmeans.fit(waves).cluster_centers_

fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(12, 4))
for i in xrange(Mnum):
    plt.subplot(1,Mnum,i+1)
    plt.plot(mu_waves[i].reshape(8,30).T)

plt.tight_layout()


2.0: the interpolation needs to happen before shifting. A better threshold might also help?
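
Before reloading, a per-channel version of the robust threshold might be worth keeping around. A sketch: the channel-wise broadcasting is my assumption; the median/0.6745 estimate is the one compared in the threshold-ideas cell below.


In [ ]:
# Sketch: robust noise estimate and threshold per channel rather than pooled over all channels.
sigma_hat = np.median(np.absolute(gut), axis=1) / 0.6745   # robust std estimate, one per channel
thresh = 4 * sigma_hat                                      # shape (8,)
mask = np.absolute(gut) > thresh[:, np.newaxis]             # broadcast over samples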


In [419]:
# Load more data until it all breaks again
# gut = np.loadtxt('Gutnisky_py.txt',usecols=range(1,100001))
gut = np.loadtxt('Gutnisky_py.txt',usecols=range(1,3129231))

In [4]:
# Threshold ideas
m = np.std(gut)
print m*4        # 4 * std
print (m-np.mean(gut))*4 # 4 * (std - mean)
print (np.median(np.absolute(gut))/0.6745)*4 # 4 * median(|x|)/0.6745, robust estimate from Quiroga, Harris and others


84.7841148329
86.6873474172
59.3031875463

In [420]:
# 4*std still seems a decent bet for detection
m = np.std(gut)
# Take absolute peak this time, to catch positive spikes
mask = np.absolute(gut)>m*4
indices = np.where(mask)
indices = np.sort(indices[1])  # Not really interested in which channel the peak was on. Or am I?

In [357]:
# Plotting when the spikes happen
plt.clf()
plt.plot(1000*indices[0])
plt.plot(indices[1])
plt.plot(np.sort(indices[1]))

indices?



In [421]:
# Remove spikes that are too close together (NB no edge filter here, unlike the earlier cell)
mask = np.diff(indices)>8 # When spikes are too close, this removes the earliest ones in a sequence
idx = np.where(mask)
spks = indices[idx]

In [312]:
fig, axes = plt.subplots(nrows=1,ncols=1,figsize=(12, 4))
plt.plot(gut.T,'k',alpha = 0.3)
plt.plot(indices,200*np.ones(len(indices)),'.',color='r')
plt.plot(spks,200*np.ones(len(spks)),'.',color='b')
plt.xlim([250,280])


Out[312]:
(250, 280)

In [422]:
# Shifted indices, for shifted waveforms
shift = np.zeros(len(spks))
index = np.zeros_like(spks)
wavesA = np.zeros((len(index),8*20))
# for i in range(20):
for i in range(len(spks)):
# x = np.reshape(gut[:,spks[i]-15:spks[i]+15],240)
    x = gut[:,spks[i]-15:spks[i]+15]
    pk = np.argmax(np.absolute(x))
    shift[i] = np.mod(pk,30)-15
    index[i] = spks[i]+shift[i]

    y = gut[:,index[i]-10:index[i]+10].T
    y = y.reshape(160)
    wavesA[i,:] = y

In [423]:
# Lots of plotting/printing for debugging
#x = waves[5,:]
#y = wavesS[5,:]
# plt.plot(x.reshape(30,8),'b')
# plt.plot(y.reshape(30,8),'r')
i =1
#plt.plot(x.reshape(30,8),'b')
#plt.plot(y.reshape(20,8),'r')
# plt.plot(np.absolute(x.T.reshape(240)),'g')
plt.plot(wavesA[i,:].reshape(20,8),'g')
print spks[i]
print pk
print np.mod(pk,30)
print shift[i]
print index[i]


226
38
8
-1.0
225

In [429]:
wavesA.shape
wavesO = np.empty_like(wavesA)
wavesO[:] = wavesA
# wavesO = np.copy.deepcopy(wavesA) #X_S.shape

PCA/K-means again, on properly aligned data
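
Rather than hand-summing explained-variance ratios later on, the component count could be read off the cumulative curve. A minimal sketch; the 90% target is arbitrary.


In [ ]:
# Sketch: pick the number of PCA components that explains ~90% of the variance.
pca_full = PCA(whiten=True).fit(wavesA)
cumvar = np.cumsum(pca_full.explained_variance_ratio_)
n_keep = np.searchsorted(cumvar, 0.9) + 1
plt.plot(cumvar)
print(n_keep)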


In [401]:
# from sklearn.decomposition import PCA
# Define PCA w/ whitening 
pca = PCA(n_components=50, whiten=True)
# Fit the PCA
X_S = pca.fit(wavesA).transform(wavesA)
# Percentage of variance explained by each component
print('explained variance ratio (first few components): %s'
      % str(pca.explained_variance_ratio_))
#from sklearn.cluster import KMeans
kmeans = KMeans(6, random_state=8)
Y_hat = kmeans.fit(X_S).labels_
plt.scatter(X_S[:,0], X_S[:,1], c=Y_hat,alpha = 0.4);
mu = kmeans.cluster_centers_
plt.scatter(mu[:,0], mu[:,1], s=100, c=np.unique(Y_hat))


explained variance ratio (first few components): [ 0.29715898  0.16354807  0.10884256  0.04283791  0.02925272  0.02651601
  0.02416472  0.02237233  0.0192831   0.01635502  0.01610419  0.01448015
  0.01307426  0.01169332  0.01051051  0.01026856  0.00954974  0.00901659
  0.00849529  0.00752237  0.00713567  0.00678858  0.00636219  0.00553538
  0.00540274  0.00519412  0.00492914  0.00484003  0.00475057  0.00442241
  0.00418283  0.00402771  0.00388604  0.00367826  0.00358422  0.00350625
  0.00334542  0.00315723  0.00303928  0.00274379  0.00257843  0.00249905
  0.00236909  0.00228243  0.00223738  0.0021973   0.00205218  0.0019727
  0.00188007  0.00179587]
Out[401]:
<matplotlib.collections.PathCollection at 0x118df80d0>

scikit-learn GMM example to find the number of clusters
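
One way to pick the number of clusters with a number rather than by eye: score a range of component counts with BIC. A sketch, assuming this sklearn version's mixture.GMM exposes bic().


In [ ]:
# Sketch: BIC across component counts to choose the number of GMM clusters.
# Assumes mixture.GMM provides .bic(); adjust for the installed sklearn version.
n_range = range(1, 11)
bics = []
for n in n_range:
    g = mixture.GMM(n_components=n, covariance_type='full', n_iter=1000)
    g.fit(X_S)
    bics.append(g.bic(X_S))
plt.plot(n_range, bics, 'o-')
plt.xlabel('n_components')
plt.ylabel('BIC')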


In [403]:
color_iter = itertools.cycle(['r', 'g', 'b', 'c', 'm','y','k'])
n = 10
#color_iter=iter(cm.rainbow(np.linspace(0,1,n)))

fig, splot = plt.subplots(nrows=1,ncols=1,figsize=(18, 5))

for i, (clf, title) in enumerate([
        (mixture.GMM(n_components=10, covariance_type='full', n_iter=10000),
         "Expectation-maximization"),
        (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01,
                       n_iter=10000),
         "Dirichlet Process,alpha=0.01"),
        (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=10.,
                       n_iter=10000),
         "Dirichlet Process,alpha=10.")]):

    clf.fit(X_S)
    splot = plt.subplot(1, 3, 1 + i)
    Y_ = clf.predict(X_S)
    
    for i, (mean, covar, color) in enumerate(zip(
            clf.means_, clf._get_covars(), color_iter)):
        v, w = linalg.eigh(covar)
        u = w[0] / linalg.norm(w[0])
        # as the DP will not use every component it has access to
        # unless it needs it, we shouldn't plot the redundant
        # components.
        if not np.any(Y_ == i):
            continue
        plt.scatter(X_S[Y_ == i, 0], X_S[Y_ == i, 1], .8, color=color)

        # Plot an ellipse to show the Gaussian component
        angle = np.arctan(u[1] / u[0])
        angle = 180 * angle / np.pi  # convert to degrees
        ell = mpl.patches.Ellipse(mean, v[0], v[1], 180 + angle, color=color)
        ell.set_clip_box(splot.bbox)
        ell.set_alpha(0.5)
        splot.add_artist(ell)

    plt.xlim(-6, 4 * np.pi - 6)
    plt.ylim(-5, 5)
    plt.title(title)
    plt.xticks(())
    plt.yticks(())

plt.show()



In [34]:
plt.hist(Y_)
max(Y_)
clf.means_


Out[34]:
array([[-0.18375941, -0.07463888,  0.16502837, -0.06900734, -0.01871269,
        -0.01779214,  0.02373176,  0.02007298, -0.02509611, -0.01147653],
       [-1.49454765, -2.52061733, -1.3542853 ,  8.60339642,  4.16794056,
         0.35117762, -1.497136  , -3.01639273,  0.19987596, -1.63493334],
       [ 2.70826119,  2.90137221,  0.19995342,  1.6926278 , -0.01939818,
        -2.00468625, -0.03251305,  1.39532597,  0.93693433, -0.54462534],
       [ 2.74596322,  2.92156772,  0.35154769,  0.39270337,  0.90808228,
         1.70476526, -0.75677798, -1.60350358, -0.89408765,  0.33771516],
       [-0.23576158, -0.12123042, -0.42321723,  0.70116598,  0.12112472,
         1.49912451,  1.10946499,  1.15554725,  0.76314448,  1.08669368],
       [-0.18364736, -0.07474266,  0.16506657, -0.0690184 , -0.01871647,
        -0.01779371,  0.02372043,  0.0200683 , -0.02508537, -0.01148891],
       [-0.18369894, -0.07469134,  0.16505258, -0.06901324, -0.01871957,
        -0.01779198,  0.02372124,  0.02006642, -0.02508503, -0.01148611],
       [ 0.27034669, -0.36398284, -0.91476871, -0.24454751, -0.17868209,
        -0.03992707, -0.05882379, -0.05607322,  0.07759205,  0.03975084],
       [-0.18360516, -0.07478165,  0.16508104, -0.06902256, -0.01871801,
        -0.01779427,  0.02371606,  0.02006644, -0.0250812 , -0.01149364],
       [-0.18360553, -0.07478126,  0.16508096, -0.06902253, -0.01871806,
        -0.01779425,  0.02371604,  0.02006641, -0.02508117, -0.01149364]])

In [462]:
# Tried generating inverse waveforms, but it's not working (not enough signal using PCA)
fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(12, 4))
invMu = pca.inverse_transform(mu)
invMu.shape
lines = np.linspace(0, 7, 8)
colours = cm.rainbow(np.linspace(0, 1, 8))
plt.plot(invMu[0,:].reshape(8,20).T,c=colours)  # NB wavesA rows were flattened from (20,8) arrays, so reshape(20,8) is probably the consistent unflatten here
# plt.legend(handles=lines)
# plt.plot(invMu[3,:])
# plt.xlim([100,120])


Out[462]:
[<matplotlib.lines.Line2D at 0x116888610>,
 <matplotlib.lines.Line2D at 0x116888110>,
 <matplotlib.lines.Line2D at 0x116888550>,
 <matplotlib.lines.Line2D at 0x116888c50>,
 <matplotlib.lines.Line2D at 0x1168883d0>,
 <matplotlib.lines.Line2D at 0x1155683d0>,
 <matplotlib.lines.Line2D at 0x1155682d0>,
 <matplotlib.lines.Line2D at 0x113be4950>]
<matplotlib.figure.Figure at 0x113be10d0>

In [ ]:
# Fit kmeans directly to data

Mnum = 7
kmeans = KMeans(Mnum)
mu_wavesA = kmeans.fit(wavesA).cluster_centers_
mu_wavesO = kmeans.fit(wavesO).cluster_centers_

In [231]:
fig, ax = plt.subplots(nrows=1,ncols=Mnum,figsize=(12, 5))
for i in xrange(Mnum):
    plt.subplot(2,Mnum,i+1)
    plt.plot(mu_wavesA[i].reshape(20,8))
    plt.title("Whitened")

plt.tight_layout()



In [230]:
fig, ax = plt.subplots(nrows=1,ncols=Mnum,figsize=(12, 5))
for i in xrange(Mnum):
    plt.subplot(2,Mnum,i+1)
    plt.plot(mu_wavesO[i].reshape(20,8))
    plt.title("Un-whitened")

plt.tight_layout()



In [401]:
z = mu_waves[0]
z.shape
plt.plot(z.reshape(8,20).T)


Out[401]:
[<matplotlib.lines.Line2D at 0x1126e1dd0>,
 <matplotlib.lines.Line2D at 0x1126e6090>,
 <matplotlib.lines.Line2D at 0x1126e62d0>,
 <matplotlib.lines.Line2D at 0x1126e6490>,
 <matplotlib.lines.Line2D at 0x1126e6650>,
 <matplotlib.lines.Line2D at 0x1126e6810>,
 <matplotlib.lines.Line2D at 0x1126e69d0>,
 <matplotlib.lines.Line2D at 0x115215ed0>]

In [408]:
sum([ 0.29715898 , 0.16354807,  0.10884256  ,0.04283791 , 0.02925272,0.02651601,0.02416472 , 0.02237233 , 0.0192831  , 0.01635502])


Out[408]:
0.7503314200000002

Matching pursuit code.

First with example data, then with real voltage (using default dictionaries, then a dictionary from a GMM/K-means fit applied to the extracted spikes).


In [277]:
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.linear_model import OrthogonalMatchingPursuitCV
from sklearn.datasets import make_sparse_coded_signal

n_components, n_features = 512, 100
n_nonzero_coefs = 17

# generate the data
###################

# y = Xw
# |x|_0 = n_nonzero_coefs

y, X, w = make_sparse_coded_signal(n_samples=5,
                                   n_components=n_components,
                                   n_features=n_features,
                                   n_nonzero_coefs=n_nonzero_coefs,
                                   random_state=0)
i = 2
idx, = w[:,i].nonzero()

# distort the clean signal
##########################
y_noisy = y[:,i] + 0.05 * np.random.randn(len(y))

# plot the sparse signal
########################
plt.figure(figsize=(7, 7))
plt.subplot(4, 1, 1)
plt.xlim(0, 512)
plt.title("Sparse signal")
plt.stem(idx, w[idx,i])

# plot the noise-free reconstruction
####################################

omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs)
omp.fit(X, y[:,i])
coef = omp.coef_
idx_r, = coef.nonzero()
plt.subplot(4, 1, 2)
plt.xlim(0, 512)
plt.title("Recovered signal from noise-free measurements")
plt.stem(idx_r, coef[idx_r])

# plot the noisy reconstruction
###############################
omp.fit(X, y_noisy)
coef = omp.coef_
idx_r, = coef.nonzero()
plt.subplot(4, 1, 3)
plt.xlim(0, 512)
plt.title("Recovered signal from noisy measurements")
plt.stem(idx_r, coef[idx_r])

# plot the noisy reconstruction with number of non-zeros set by CV
##################################################################
omp_cv = OrthogonalMatchingPursuitCV()
omp_cv.fit(X, y_noisy)
coef = omp_cv.coef_
idx_r, = coef.nonzero()
plt.subplot(4, 1, 4)
plt.xlim(0, 512)
plt.title("Recovered signal from noisy measurements with CV")
plt.stem(idx_r, coef[idx_r])

plt.subplots_adjust(0.06, 0.04, 0.94, 0.90, 0.20, 0.38)
plt.suptitle('Sparse signal recovery with Orthogonal Matching Pursuit',
             fontsize=16)
plt.show()



In [289]:
omp_cv = OrthogonalMatchingPursuit()
omp_cv.fit(X, y)
coef = omp_cv.coef_
plt.figure(figsize=(15, 5))
for i in range(5):
    idx_r, = coef[i,:].nonzero()
    plt.subplot(5, 2, (i+1)*2-1)
    plt.xlim(0, 512)
   # plt.title("Recovered signal from noisy measurements with CV")
    plt.stem(idx_r, coef[i,idx_r])
    plt.subplot(5,2,(i+1)*2)
    plt.plot(y[:,i])



In [275]:
print y_noisy.shape
print y.shape
print w.shape
print idx.shape
print X.shape
print coef.shape
x = coef.nonzero()
# w[45:80]


(100,)
(100, 5)
(512, 5)
(17,)
(100, 512)
(5, 512)

In [267]:
plt.plot(w)
# plt.plot(y_noisy)
plt.plot(coef)
Tdata.shape


Out[267]:
(1000, 160)

Start with the 1D case.


In [291]:
sig = gut[:,0:512]
sig.shape
plt.plot(sig[6,:])


Out[291]:
[<matplotlib.lines.Line2D at 0x119272fd0>]

In [38]:
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs)
omp.fit(sig, y)
coef = omp.coef_
idx_r, = coef.nonzero()
plt.xlim(0, 512)
plt.title("Recovered signal from noise-free measurements")
plt.stem(idx_r, coef[idx_r])

In [303]:
Tdata = wavesO[0:1000,:]
Tdata -= np.mean(Tdata, axis = 0)
Tdata /= np.std(Tdata, axis = 0)
x = wavesO[3,:].reshape(20,8)
plt.plot(x)
wavesO.shape


Out[303]:
(30072, 160)

Do dictionary learning on the extracted spikes, then use that dictionary for the matching pursuit.

This is the long version; see the short version below.


In [384]:
##########
# EXAMPLE
##########

import time
from sklearn import datasets
from sklearn.cluster import MiniBatchKMeans
from sklearn.feature_extraction.image import extract_patches_2d

faces = datasets.fetch_olivetti_faces()

###############################################################################
# Learn the dictionary of images

print('Learning the dictionary... ')
rng = np.random.RandomState(0)
kmeans = MiniBatchKMeans(n_clusters=81, random_state=rng, verbose=True)
patch_size = (20, 8)

buffer = []
index = 1
t0 = time.time()


Learning the dictionary... 

In [386]:
##########
# EXAMPLE
##########

# The online learning part: cycle over the whole dataset 6 times
index = 0
for _ in range(6):
    for img in faces.images:
        data = extract_patches_2d(img, patch_size, max_patches=50,
                                  random_state=rng)
        data = np.reshape(data, (len(data), -1))
        buffer.append(data)
        index += 1
        if index % 10 == 0:
            data = np.concatenate(buffer, axis=0)
#             print data.shape
            data -= np.mean(data, axis=0)
            data /= np.std(data, axis=0)

            kmeans.partial_fit(data)
            buffer = []
        if index % 100 == 0:
            print('Partial fit of %4i out of %i'
                  % (index, 6 * len(faces.images)))

dt = time.time() - t0
print('done in %.2fs.' % dt)


Partial fit of  100 out of 2400
Partial fit of  200 out of 2400
Partial fit of  300 out of 2400
Partial fit of  400 out of 2400
Partial fit of  500 out of 2400
Partial fit of  600 out of 2400
Partial fit of  700 out of 2400
Partial fit of  800 out of 2400
Partial fit of  900 out of 2400
Partial fit of 1000 out of 2400
Partial fit of 1100 out of 2400
Partial fit of 1200 out of 2400
Partial fit of 1300 out of 2400
Partial fit of 1400 out of 2400
Partial fit of 1500 out of 2400
Partial fit of 1600 out of 2400
Partial fit of 1700 out of 2400
Partial fit of 1800 out of 2400
Partial fit of 1900 out of 2400
Partial fit of 2000 out of 2400
Partial fit of 2100 out of 2400
Partial fit of 2200 out of 2400
Partial fit of 2300 out of 2400
Partial fit of 2400 out of 2400
done in 13.19s.

In [440]:
##########
# REAL
##########

# Online learning of a dictionary for waveform reconstruction, cycling over waveforms
del dataW
print('Learning the dictionary... ')
rng = np.random.RandomState(0)
kmeans = MiniBatchKMeans(n_clusters=10, random_state=rng, verbose=True)
patch_size = (20, 8)

buffer = []
index = 1
t0 = time.time()

index = 0
for _ in range(30): #Trying batches of waveforms, 1000 at a time
    index += 1
    dataW = wavesO[((index+1)*1000)-1000:((index+1)*1000),:]  # NB index starts at 1, so rows 0:1000 are never used
#     print index
#         if index % 10 == 0:
#             data = np.concatenate(buffer, axis=0)
    dataW -= np.mean(dataW, axis=0)  # NB dataW is a view into wavesO, so this also modifies wavesO in place
#     print np.std(dataW, axis=0)
#     x = np.std(dataW, axis=0)
#     dataW /= np.std(dataW, axis=0)
    buffer.append(dataW)
#     print dataW.shape
    kmeans.partial_fit(dataW)
    buffer = []
#             buffer = []
#         if index % 100 == 0:
    print('Partial fit of %4i out of %i'
            % (index, 30))

dt = time.time() - t0
print('done in %.2fs.' % dt)


Learning the dictionary... 
Partial fit of    1 out of 30
Partial fit of    2 out of 30
Partial fit of    3 out of 30
Partial fit of    4 out of 30
Partial fit of    5 out of 30
Partial fit of    6 out of 30
Partial fit of    7 out of 30
Partial fit of    8 out of 30
Partial fit of    9 out of 30
Partial fit of   10 out of 30
Partial fit of   11 out of 30
Partial fit of   12 out of 30
Partial fit of   13 out of 30
Partial fit of   14 out of 30
Partial fit of   15 out of 30
Partial fit of   16 out of 30
Partial fit of   17 out of 30
Partial fit of   18 out of 30
Partial fit of   19 out of 30
Partial fit of   20 out of 30
Partial fit of   21 out of 30
Partial fit of   22 out of 30
Partial fit of   23 out of 30
Partial fit of   24 out of 30
Partial fit of   25 out of 30
Partial fit of   26 out of 30
Partial fit of   27 out of 30
Partial fit of   28 out of 30
Partial fit of   29 out of 30
Partial fit of   30 out of 30
done in 0.10s.

In [446]:
##########
# REAL
##########

###############################################################################
# Plot the results
plt.figure(figsize=(15, 15))
for i, wave in enumerate(kmeans.cluster_centers_):
    plt.subplot(9, 9, i + 1)
#     plt.imshow(wave.reshape(patch_size), cmap=plt.cm.gray,
#                interpolation='nearest')
    plt.plot(wave.reshape(patch_size))
    plt.xticks(())
    plt.yticks(())


plt.suptitle('Dictionary of spikes\nTrain time %.1fs on %d spikes' %
             (dt, 30000), fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

plt.show()


Quick version, with mini-batch dictionary learning and sparse coding via Orthogonal Matching Pursuit


In [448]:
##########################
# Saving the relevant data
##########################
import pickle

pickle.dump(wavesO, open("wavesO.p","wb"))
pickle.dump(gut,open("raw_voltage.p","wb"))

In [ ]:
###########################
# Loading the relevant data
###########################
import pickle
wavesO = pickle.load(open("wavesO.p","rb"))
gut = pickle.load(open("raw_voltage.p","rb"))

In [549]:
#######################################
# Prepare the data for training/testing
#######################################

SL = 20 # sample length
# Training with raw voltage (one channel)
dict_data = gut[6,:]
training_data = dict_data[0:600000]
training_data = training_data.reshape(training_data.shape[0]/SL,SL)
print training_data.shape

# Testing with raw voltage (one channel)
test_data = dict_data[1000001:1600001].reshape(600000/SL, SL)
print test_data.shape

# Training with isolated spikes (one channel)
training_data_s = np.empty((30000,SL))
for i in range(len(training_data_s)):
    training_data_s[i,:] = wavesO[i,:].reshape(SL,8)[:,6]

print training_data_s.shape
# TO DO: scale up to multiple channels (should do better)

# Plot examples of each kind
plt.figure(figsize=(16, 4))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.plot(training_data[i])
    plt.plot(test_data[i])
    plt.plot(training_data_s[i])


(30000, 20)
(30000, 20)
(30000, 20)
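
The multi-channel TO DO above is mostly a reshape. A sketch: the downstream MiniBatchDictionaryLearning/SparseCoder calls would then need the 160-wide patches too.


In [ ]:
# Sketch: multi-channel training patches for the dictionary.
# Raw voltage: SL-sample x 8-channel patches, channels interleaved within each row.
training_data_m = gut[:, 0:600000].T.reshape(600000 // SL, SL * 8)
training_data_m = training_data_m - np.mean(training_data_m, axis=0)
print(training_data_m.shape)

# Isolated spikes: wavesO rows are already full 8-channel waveforms (SL*8 = 160 wide).
training_data_sm = wavesO[0:30000, :]
print(training_data_sm.shape)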

In [593]:
######################
# Build the dictionary
######################
from sklearn.decomposition import MiniBatchDictionaryLearning
from time import time
n_components = 128
dico = MiniBatchDictionaryLearning(n_components=n_components, alpha=1, n_iter=200)
t0 = time()
D = dico.fit(training_data).components_
t1 = time() - t0
print('Raw data trained in %d seconds' % t1)

t2 = time()
DS = dico.fit(training_data_s).components_
t3 = time() - t2
print('Spikes trained in %d seconds' % t3)


Raw data trained in 4 seconds
Spikes trained in 4 seconds

In [594]:
######################
# Plot the dictionaries
######################
import itertools
fig, axes = plt.subplots(6,6)
fig.set_size_inches(16,12)
for i, j in itertools.product(range(1,7), range(6)):
    axes[i-1][j].plot(D[6*(i-1)+j])
plt.suptitle('Dictionary from raw voltages',fontsize=16)

fig, axes = plt.subplots(6,6)
fig.set_size_inches(16,12)
for i, j in itertools.product(range(1,7), range(6)):
    axes[i-1][j].plot(DS[6*(i-1)+j])
plt.suptitle('Dictionary from spikes',fontsize=16)


Out[594]:
<matplotlib.text.Text at 0x158e079d0>

In [595]:
#################################################################
# Reconst using sparsity-constrained Orthogonal Matching Pursuit
#################################################################

result = np.ndarray((test_data.shape[0],SL))
print result.shape
from sklearn.decomposition import SparseCoder

coder = SparseCoder(dictionary = D, transform_n_nonzero_coefs=6, transform_alpha=0.8, transform_algorithm="omp")

t0 = time()
result = coder.transform(test_data)
t1 = time() - t0
print('Coded signal using OMP and learned raw-data dictionary in %d seconds.' % t1)

coderS = SparseCoder(dictionary = DS, transform_n_nonzero_coefs=6, transform_alpha=0.8, transform_algorithm="omp")

t2 = time()
resultS = coderS.transform(test_data)
t3 = time() - t2
print('Coded signal using OMP and learned spikes-only dictionary in %d seconds.' % t3)


DC = np.zeros((2*len(D),SL))
DC[0:len(D),:] = D
DC[len(D):2*len(D),:] = DS
print DC.shape

coderC = SparseCoder(dictionary = DC, transform_n_nonzero_coefs=6, transform_alpha=0.8, transform_algorithm="omp")

t4 = time()
resultC = coderC.transform(test_data)
t5 = time() - t4
print('Coded signal using OMP and combined dictionaries in %d seconds.' % t5)


(30000, 20)
Coded signal using OMP and learned raw-data dictionary in 13 seconds.
Coded signal using OMP and learned spikes-only dictionary in 13 seconds.
(256, 20)
Coded signal using OMP and combined dictionaries in 15 seconds.

In [596]:
################################################################
# Generate output from reconstructed result and coded dictionary
################################################################
orig = test_data.reshape(len(test_data)*SL)
# test_data.shape
# orig.shape
# plt.plot(orig[0:100])
# plt.plot(test_data[0,:])
out = np.zeros(orig.shape)
outS = np.zeros(orig.shape)
outC = np.zeros(orig.shape)

print result.shape[0]
for n in range(result.shape[0]):
    out[n*SL:(n+1)*SL] = np.sum(D.T*result[n],axis=1)
    outS[n*SL:(n+1)*SL] = np.sum(DS.T*resultS[n],axis=1)
    outC[n*SL:(n+1)*SL] = np.sum(DC.T*resultC[n],axis=1)


30000
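
The per-patch loop above is equivalent to a single matrix product; a minimal sketch of the vectorised form:


In [ ]:
# Sketch: vectorised reconstruction -- each coded patch is coefficients x dictionary,
# so the whole trace is just a dot product, flattened back to 1D.
out_v  = result.dot(D).reshape(-1)
outS_v = resultS.dot(DS).reshape(-1)
outC_v = resultC.dot(DC).reshape(-1)
print(np.allclose(out_v, out))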

In [597]:
##############################
# Plot waveform reconstruction
##############################
r = ([27200,27400])
fig, axes = plt.subplots(3,3)
fig.set_size_inches(20,8)

# Reconstruction from raw waveform dictionary
axes[0,0].plot(orig[r[0]:r[1]])
axes[1,0].plot(out[r[0]:r[1]], 'g')
axes[2,0].plot((out[r[0]:r[1]]-orig[r[0]:r[1]])**2, 'r')

# Reconstruction from spike-derived dictionary
axes[0,1].plot(orig[r[0]:r[1]])
axes[1,1].plot(outS[r[0]:r[1]], 'g')
axes[2,1].plot((outS[r[0]:r[1]]-orig[r[0]:r[1]])**2, 'r')

# Reconstruction from combined dictionary
axes[0,2].plot(orig[r[0]:r[1]])
axes[1,2].plot(outC[r[0]:r[1]], 'g')
axes[2,2].plot((outC[r[0]:r[1]]-orig[r[0]:r[1]])**2, 'r')


Out[597]:
[<matplotlib.lines.Line2D at 0x156ad8a50>]
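
To compare the three dictionaries with a single number rather than by eye, the squared errors plotted above can be summed over the whole test trace; a quick sketch:


In [ ]:
# Sketch: total squared reconstruction error per dictionary over the test trace.
for name, rec in [('raw', out), ('spikes', outS), ('combined', outC)]:
    print('%s: %.1f' % (name, np.sum((rec - orig) ** 2)))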

TO DO: Construct a custom dictionary from time- and amplitude-shifted spikes (see the sketch below).

Explore different sparseness penalties to ensure good spike identification.

Work out how to recover the dictionary loadings over time, to recover neuron identities.

Try a combined dictionary of spikes + 'multi-unit noise' to see if that helps.

Work out whether OMP uses a linear superposition of atoms (otherwise we lose overlapping-spike detection).

Finally, try some Pillow/Simoncelli hacks for the dictionary/test-data choice, i.e. something more heuristic.
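
A possible starting point for the first TO DO item: augment the spike-derived dictionary with time-shifted copies of each atom via np.roll. Amplitude scaling is already handled by the sparse coefficients, so only time shifts are enumerated here; the +/- 4 sample range is a guess.


In [ ]:
# Sketch: time-shifted copies of each spike-dictionary atom.
shifts = range(-4, 5)                                # +/- 4 samples
D_shifted = np.vstack([np.roll(DS, s, axis=1) for s in shifts])
print(D_shifted.shape)                               # (len(shifts) * n_components, SL)
# coderShift = SparseCoder(dictionary=D_shifted, transform_n_nonzero_coefs=6,
#                          transform_algorithm="omp")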


In [590]:
test_data.shape[0]
result?

In [582]:
n_components = 72
print n_components
len(D)


72
Out[582]:
72

In [568]:
DC = np.zeros((72,20))
DC[0:36,:] = D
DC[36:72,:] = DS
print DC.shape


(72, 20)


In [438]:
# index = 2
# x = wavesO[((index+1)*1000)-1000:((index+1)*1000),:]
# print dataW.shape
# print index
# print dataW.shape
# wavesO.shape

x = kmeans.cluster_centers_
x.shape
plt.plot(x[0,:].reshape(20,8))


Out[438]:
[<matplotlib.lines.Line2D at 0x118adefd0>,
 <matplotlib.lines.Line2D at 0x118ade590>,
 <matplotlib.lines.Line2D at 0x118ade650>,
 <matplotlib.lines.Line2D at 0x118ade890>,
 <matplotlib.lines.Line2D at 0x118ace190>,
 <matplotlib.lines.Line2D at 0x118acebd0>,
 <matplotlib.lines.Line2D at 0x118ace390>,
 <matplotlib.lines.Line2D at 0x117a0de10>]

In [427]:
mask = np.isnan(dataW)
x = np.where(mask)
# indices = np.sort(indices[1])
# x?
plt.plot(wavesO[13085,:])


Out[427]:
[<matplotlib.lines.Line2D at 0x1179f85d0>]

In [319]:
print faces.images.shape
print len(faces.images)
print wavesO.shape


(400, 64, 64)
400
(30072, 160)

In [309]:
from sklearn import datasets
faces = datasets.fetch_olivetti_faces()
faces.images.shape
64*64/120
# plt.imshow(faces.images[10])


Out[309]:
34

In [216]:
# MOVING TO BOTTOM
# Let's start with a better threshold using sklearn's kde
# NOTE KDE min is likely very similar to 4*std
# KDE based threshold should be used AFTER spike identification, for each cell, to decide what to keep
from sklearn.neighbors import KernelDensity

x = waves.reshape(waves.size)
x_grid = np.linspace(min(x),max(x), 100)
kde_skl = KernelDensity(bandwidth=0.01)
kde_skl.fit(x[:, np.newaxis])
# score_samples() returns the log-likelihood of the samples
log_pdf = kde_skl.score_samples(x_grid[:, np.newaxis])
fig, ax = plt.subplots(1, 2,figsize=(12, 4))
ax[0].plot(x_grid,log_pdf)
ax[1].hist(x,bins=int(max(x)-min(x)))
print x_grid[np.argmin(log_pdf)]


-362.0
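
A rough sketch of the per-cell use suggested in the comments above: fit a KDE to each cluster's peak amplitudes and take the density minimum as that cell's acceptance threshold. Assumes Y_hat and wavesA from the clustering cells are still in scope; the bandwidth is a guess.


In [ ]:
# Sketch: per-cluster amplitude threshold from the minimum of a KDE fit.
from sklearn.neighbors import KernelDensity

peaks = np.max(np.abs(wavesA), axis=1)
for label in np.unique(Y_hat):
    amps = peaks[Y_hat == label]
    grid = np.linspace(amps.min(), amps.max(), 200)
    kde = KernelDensity(bandwidth=10.0).fit(amps[:, np.newaxis])
    log_pdf = kde.score_samples(grid[:, np.newaxis])
    print('%s: %s' % (label, grid[np.argmin(log_pdf)]))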

Alternative code for starting a parallel session


In [2]:
cluster = Client(profile="default")
lb_view = cluster.load_balanced_view()
cluster.block = True

print "Profile: %s" % cluster.profile
print "Engines: %s" % len(lb_view)
print cluster.ids
print lb_view
def f(x):
    result = 1.0
    for counter in range(100000):
        result = (result * x * 0.5)
        if result % 5 == 0:
            result -=4
    return result
%%timeit -r 1 -n 1
result = []
for i in range(1000):
    result.append(f(i))
%%timeit -r 1 -n 1
result = lb_view.map(f, range(1000), block=True)
lb_view.map(sum,[1,2,3])
%%timeit -r 1 -n 1
result = f.map(range(1000))
print "Results Count: %s" % len(result)
del lb_view


Profile: default
Engines: 4
[0, 1, 2, 3]
<LoadBalancedView None>