# In[ ]:
# NOTE: this function still needs further debugging (the matrix dimensions do not match).
from collections import Counter
from itertools import combinations

import numpy as np

def synchrony_analysis_efficient(spikeMat, numCoactive, minRpts=1):
    """Find groups of ``numCoactive`` neurons that fire together and count how often.

    A "binary word" (cell assembly) is a set of ``numCoactive`` neuron indices
    that are all active (value > 0) in the same time bin. Every time bin with
    at least ``numCoactive`` active neurons contributes each
    ``numCoactive``-sized combination of its active neurons exactly once.
    An assembly's frequency is therefore the number of time bins in which
    all of its members were active together.

    Parameters
    ----------
    spikeMat : array_like, shape (numNeurons, numSteps)
        Spike matrix; rows are neurons, columns are time bins (e.g. several
        trials concatenated along the time axis). Any value > 0 counts as the
        neuron being active in that bin.
    numCoactive : int
        Assembly size, i.e. how many neurons must fire together; must be >= 1.
        With ``numCoactive == 1`` every active neuron is its own "assembly".
    minRpts : int, optional
        Minimum number of co-activations an assembly needs to be kept
        (default 1, the value previously hard-coded).

    Returns
    -------
    frequency : ndarray of int, shape (numAssemblies,)
        Number of time bins in which each assembly fired together. The values
        are in ascending order; ties keep the order of first occurrence.
    binaryWords : ndarray of int, shape (numAssemblies, numCoactive)
        Neuron (row) indices of each assembly, ascending within each row.
        Row ``i`` corresponds to ``frequency[i]``.

    Raises
    ------
    ValueError
        If ``spikeMat`` is not 2-D or ``numCoactive`` is less than 1.
    """
    spikeMat = np.asarray(spikeMat)
    if spikeMat.ndim != 2:
        raise ValueError("spikeMat must be a 2-D (neurons x time bins) array")
    if numCoactive < 1:
        raise ValueError("numCoactive must be >= 1")

    # Use "active" (> 0) rather than raw spike counts. A neuron emitting
    # several spikes in one bin must not count as several co-active neurons.
    active = spikeMat > 0
    numActive = active.sum(axis=0)

    # Counter preserves insertion order, i.e. the order of first appearance.
    counts = Counter()
    # Bins with fewer active neurons cannot contain an assembly of this size.
    for t in np.flatnonzero(numActive >= numCoactive):
        activeCells = np.flatnonzero(active[:, t]).tolist()
        # Each combination appears at most once per bin, so its running count
        # equals the number of bins in which all of its cells co-fired.
        counts.update(combinations(activeCells, numCoactive))

    # Keep sufficiently repeated assemblies. The stable sort by frequency
    # groups them by repeat count, as the original merge step intended.
    kept = sorted(
        ((word, n) for word, n in counts.items() if n >= minRpts),
        key=lambda item: item[1],
    )

    frequency = np.array([n for _, n in kept], dtype=int)
    # reshape keeps a well-formed (0, numCoactive) array when nothing is found.
    binaryWords = np.array([word for word, _ in kept], dtype=int).reshape(
        -1, numCoactive
    )
    return frequency, binaryWords