# In [ ]:
# NOTE: this function still needs further debugging to confirm that the matrix dimensions match.
import numpy as np
from scipy import signal

def calculate_subset_index(spikeMat, numFramesPerStim, timewindow=1, numLateTrials=4):
    """Compute the subset index between the first trial and the last trials.

    The steps for each of the last ``numLateTrials`` trials are:

    1. Box-filter the late trial along the frame axis with a window of
       ``timewindow`` frames.
    2. Binarize the filtered trial.
    3. Multiply it element-wise with the first trial's raster.

    The index is the summed overlap across all late trials divided by the
    summed number of active (neuron, frame) entries in the binarized late
    trials.

    Parameters
    ----------
    spikeMat : array_like, shape (numRpts, numNeuron, numFrames)
        Spike raster for every trial (repeat).
    numFramesPerStim : int
        Not used by the computation; kept so existing callers keep working.
    timewindow : int, optional
        Width, in frames, of the box filter applied to the late trials.
        The default of 1 leaves the late trials unsmoothed.
    numLateTrials : int, optional
        Number of trailing trials compared with the first trial
        (default 4, i.e. "the last 4 trials").

    Returns
    -------
    float
        The subset index, or ``nan`` if the late trials contain no spikes.

    Raises
    ------
    ValueError
        If ``spikeMat`` is not 3-D, ``numLateTrials`` is < 1, or there are
        not more trials than ``numLateTrials`` (the first trial must not be
        one of the late trials).
    """
    spikeMat = np.asarray(spikeMat)
    if spikeMat.ndim != 3:
        raise ValueError("spikeMat must be 3-D: (trials, neurons, frames)")
    numRpts = spikeMat.shape[0]
    if numLateTrials < 1 or numRpts <= numLateTrials:
        raise ValueError(
            f"need more than numLateTrials={numLateTrials} trials, got {numRpts}"
        )

    # Reference raster: the first trial, used as-is (neither smoothed nor
    # binarized, matching the original computation).
    # NOTE(review): if spikeMat holds counts > 1, the overlap can exceed the
    # late-spike count; confirm whether the first trial should be binarized.
    spikeEarly = spikeMat[0]
    window = np.ones((1, timewindow))

    sharedTotal = 0.0
    lateTotal = 0.0
    # Compare the first trial against the last `numLateTrials` trials.
    # Indexing from the end avoids the hard-coded range(17, 21), which
    # needed 21 trials.
    for rep in range(numRpts - numLateTrials, numRpts):
        # convolve2d returns a new float array, so spikeMat is not mutated.
        spikeLate = signal.convolve2d(spikeMat[rep], window, 'same')
        spikeLate[spikeLate > 0] = 1
        # Entries active in both the first trial and this late trial.
        sharedTotal += np.multiply(spikeEarly, spikeLate).sum()
        lateTotal += spikeLate.sum()

    if lateTotal == 0:
        # No late spikes: the ratio is undefined. Return nan without
        # numpy's divide-by-zero warning.
        return float('nan')
    return sharedTotal / lateTotal