Computing all pairwise correlograms within a pool of neurons

Pure Python and Cython implementations of an algorithm that computes all pairwise correlograms between neurons, given a set of spike trains.

Imports


In [1]:
from itertools import product

import matplotlib.pyplot as plt
import numpy as np

In [2]:
%load_ext cythonmagic

Function definitions

Pure Python


In [3]:
def compute_correlograms(spiketimes, neurons, neurons_to_update=None,
    ncorrbins=100, corrbin=.001):
    """Compute all pairwise cross-correlograms (and auto-correlograms)
    between all neurons within a pool of neurons.

    Arguments:

      * spiketimes: a sorted 1D array with all spike times, in seconds,
      * neurons: the (non-negative integer) neuron index of each spike, same
        shape as spiketimes,
      * neurons_to_update=None: a 1D array with the neurons for which the
        correlograms need to be computed. The correlograms between these
        neurons and all neurons are computed. Defaults to all neurons.
      * ncorrbins=100: the total number of bins in the correlograms. Must be
        an even number.
      * corrbin=.001: the bin size, in seconds.

    Returns:

      * correlograms: a dictionary `(neuron0, neuron1): correlogram` where
        correlogram is a 1D ncorrbins-long array with spike count values in
        each bin. With `n = ncorrbins // 2`, bin `k` counts spikes of neuron1
        whose lag relative to a spike of neuron0 lies between
        `(k - n) * corrbin` and `(k - n + 1) * corrbin`. Only lags strictly
        inside `(-n * corrbin, n * corrbin)` are counted, and a spike is never
        paired with itself. An empty dictionary is returned when there are
        no spikes.

    """
    # Ensure ncorrbins is an even number.
    assert ncorrbins % 2 == 0

    nspikes = len(spiketimes)
    if nspikes == 0:
        return {}

    # n is the index of the first non-negative-lag bin.
    n = ncorrbins // 2
    halfwidth = corrbin * n

    # The correlograms live in one flat table whose row for the pair
    # (neuron0, neuron1) is `stride * neuron0 + neuron1`.
    neurons_unique = np.unique(neurons)
    stride = neurons_unique[-1] + 1

    if neurons_to_update is None:
        neurons_to_update = neurons_unique
    neurons_mask = np.zeros(stride, dtype=bool)
    neurons_mask[neurons_to_update] = True

    correlograms = np.zeros((stride * stride, ncorrbins), dtype=np.int32)

    # Loop through all spikes, across all neurons, in time order.
    for i in range(nspikes):
        t0, neuron0 = spiketimes[i], neurons[i]
        # Skip reference spikes of neurons that do not need to be processed.
        if not neurons_mask[neuron0]:
            continue
        row0 = stride * neuron0
        t0min, t0max = t0 - halfwidth, t0 + halfwidth

        # Go forward in time, strictly within the correlogram half-width.
        j = i + 1
        while j < nspikes and spiketimes[j] < t0max:
            k = int((spiketimes[j] - t0) / corrbin) + n
            # Rounding of (t1 - t0) / corrbin can still yield ncorrbins at the
            # very edge of the window; such spikes are out of range.
            if k < ncorrbins:
                correlograms[row0 + neurons[j], k] += 1
            j += 1

        # Go backward in time, strictly within the correlogram half-width.
        j = i - 1
        while j >= 0 and t0min < spiketimes[j]:
            # int() truncates toward zero, so lags in (-corrbin, 0] land in
            # bin n - 1.
            k = int((spiketimes[j] - t0) / corrbin) + n - 1
            # k == -1 would silently wrap around to the last (+lag) bin.
            if k >= 0:
                correlograms[row0 + neurons[j], k] += 1
            j -= 1

    return {(neuron0, neuron1): correlograms[stride * neuron0 + neuron1, :]
            for neuron0 in neurons_to_update for neuron1 in neurons_unique}

Cython version


In [4]:
%%cython

import numpy as np
cimport numpy as np
DTYPE = np.double
ctypedef np.double_t DTYPE_t
# Integer dtype for neuron indices and counts; `np.int_` is the NumPy dtype
# that the removed `np.int` alias resolved to.
DTYPEI = np.int_
ctypedef np.int_t DTYPEI_t

def compute_correlograms_cython(
     np.ndarray[DTYPE_t, ndim=1] spiketimes,
     np.ndarray[DTYPEI_t, ndim=1] neurons,
     np.ndarray[DTYPEI_t, ndim=1] neurons_to_update=None,
     int ncorrbins=100,
     double corrbin=.001):
    """Cython version of compute_correlograms.

    Arguments:

      * spiketimes: a sorted 1D float64 array with all spike times, in seconds,
      * neurons: a 1D array (dtype np.int_) with the neuron index of each
        spike, same shape as spiketimes,
      * neurons_to_update=None: a 1D np.int_ array with the neurons for which
        the correlograms are computed (against all neurons). Defaults to all
        neurons.
      * ncorrbins=100: the total number of bins. Must be an even number.
      * corrbin=.001: the bin size, in seconds.

    Returns a dictionary `(neuron0, neuron1): correlogram`, where correlogram
    is a 1D ncorrbins-long array of spike counts. With `n = ncorrbins // 2`,
    bin `k` counts lags between `(k - n) * corrbin` and `(k - n + 1) * corrbin`;
    only lags strictly inside `(-n * corrbin, n * corrbin)` are counted.
    An empty dictionary is returned when there are no spikes.
    """
    # Ensure ncorrbins is an even number.
    assert ncorrbins % 2 == 0

    # n is the index of the first non-negative-lag bin.
    cdef int n = ncorrbins // 2
    cdef double halfwidth = corrbin * n

    cdef int nspikes = spiketimes.shape[0]
    if nspikes == 0:
        return {}

    cdef int i, j, neuron0, neuron1, k, ind
    # Double precision is required: spike times of hundreds of seconds stored
    # as single-precision floats lose tens of microseconds, which shifts
    # spikes across millisecond bin edges.
    cdef double t0, t0min, t0max

    # Unique neurons; the flat correlogram table has one row per pair,
    # at row `stride * neuron0 + neuron1`.
    cdef np.ndarray[DTYPEI_t, ndim=1] neurons_unique = np.unique(neurons)
    cdef int neuron_max = neurons_unique[-1]
    cdef int stride = neuron_max + 1

    # Neurons to update.
    if neurons_to_update is None:
        neurons_to_update = neurons_unique
    cdef np.ndarray[DTYPEI_t, ndim=1] neurons_mask = np.zeros(stride, dtype=DTYPEI)
    neurons_mask[neurons_to_update] = 1

    # Initialize the correlograms.
    cdef np.ndarray[DTYPEI_t, ndim=2] correlograms = np.zeros(
        (stride * stride, ncorrbins), dtype=DTYPEI)

    # Loop through all spikes, across all neurons, in time order.
    for i in range(nspikes):
        neuron0 = neurons[i]
        # Skip reference spikes of neurons that do not need to be processed.
        if not neurons_mask[neuron0]:
            continue
        t0 = spiketimes[i]
        t0min = t0 - halfwidth
        t0max = t0 + halfwidth

        # Go forward in time, strictly within the correlogram half-width.
        j = i + 1
        while j < nspikes and spiketimes[j] < t0max:
            k = <int>((spiketimes[j] - t0) / corrbin) + n
            # Rounding can yield ncorrbins at the very edge of the window.
            if k < ncorrbins:
                ind = stride * neuron0 + neurons[j]
                correlograms[ind, k] += 1
            j += 1

        # Go backward in time, strictly within the correlogram half-width.
        j = i - 1
        while j >= 0 and t0min < spiketimes[j]:
            # The C cast truncates toward zero, so lags in (-corrbin, 0] land
            # in bin n - 1.
            k = <int>((spiketimes[j] - t0) / corrbin) + n - 1
            # k == -1 would wrap around (wraparound is on) to the last bin.
            if k >= 0:
                ind = stride * neuron0 + neurons[j]
                correlograms[ind, k] += 1
            j -= 1

    return {(neuron0, neuron1): correlograms[stride * neuron0 + neuron1, :]
            for neuron0 in neurons_to_update for neuron1 in neurons_unique}

Tests

Computing all pairwise correlograms within a pool of 100 neurons, with 100 lag bins (±50 ms, 1 ms per bin), for a total of 100000 spikes.


In [5]:
# Simulation parameters: total number of spikes and size of the neuron pool.
nspikes, nneurons = 100000, 100
# Correlogram parameters: number of bins (must be even) and bin width in seconds.
ncorrbins, corrbin = 100, 1e-3

In [6]:
# Fixed seed so that Restart & Run All reproduces the same data and figures.
np.random.seed(0)
# Cumulative sum of exponential inter-spike intervals (mean 5 ms): a Poisson
# spike train, already sorted as the correlogram functions require.
spiketimes = np.cumsum(np.random.exponential(scale=.005, size=nspikes))
# `high` is exclusive: labels are drawn from 0 .. nneurons - 1, so every one
# of the nneurons neurons can fire.
neurons = np.random.randint(low=0, high=nneurons, size=nspikes)

Plotting the correlograms of the first 4 × 4 neuron pairs (computed with the Cython version).


In [7]:
correlograms = compute_correlograms_cython(spiketimes, neurons,
                                           ncorrbins=ncorrbins, corrbin=corrbin)

# Bin k covers lags from (k - ncorrbins // 2) to (k - ncorrbins // 2 + 1) bins,
# so each bar is drawn from its left edge (align='edge'); lags are in ms.
nplot = 4
binwidth_ms = corrbin * 1e3
lags_ms = (np.arange(ncorrbins) - ncorrbins // 2) * binwidth_ms
fig, axes = plt.subplots(nplot, nplot, sharex=True, figsize=(12, 9))
for i in range(nplot):
    for j in range(nplot):
        ax = axes[i, j]
        ax.bar(lags_ms, correlograms[i, j], width=binwidth_ms, align='edge',
               ec='none')
        ax.set_xlim(lags_ms[0], lags_ms[-1] + binwidth_ms)
        ax.set_title('neurons %d, %d' % (i, j), fontsize=8)
        ax.grid(True)
for ax in axes[-1, :]:
    ax.set_xlabel('lag (ms)')
fig.tight_layout()


Benchmarks

Pure Python version.


In [8]:
%%timeit -n 2 -r 2
correlograms = compute_correlograms(spiketimes, neurons,
                                     ncorrbins=ncorrbins, corrbin=corrbin)


2 loops, best of 2: 21.8 s per loop

Cython version.


In [9]:
%%timeit -n 2 -r 2
correlograms = compute_correlograms_cython(spiketimes, neurons,
                                    ncorrbins=ncorrbins, corrbin=corrbin)


2 loops, best of 2: 788 ms per loop