Score candidate events from multiple sessions using the replay score of Davidson et al. (2009).


In [3]:
import numpy as np
import os
import pandas as pd
import warnings

import nelpy as nel

warnings.filterwarnings("ignore")

Load experimental data


In [4]:
datadirs = ['/home/etienne/Dropbox/neoReader/',
            'C:/Etienne/Dropbox/neoReader/',
            'D:/Dropbox/neoReader/']

fileroot = next((datadir for datadir in datadirs if os.path.isdir(datadir)), None)
# note: developed against pandas 0.19.2 (conda install pandas=0.19.2)
if fileroot is None:
    raise FileNotFoundError('datadir not found')

load_from_nel = True

# load from nel file:
if load_from_nel:
    jar = nel.load_pkl(fileroot + 'gor01vvp01_processed_speed.nel')
    exp_data = jar.exp_data
    aux_data = jar.aux_data
    del jar
    
    jar = nel.load_pkl(fileroot + 'gor01vvp01_tables_speed.nel')
    df = jar.df
    df2 = jar.df2
    del jar

Define subset of sessions to score


In [5]:
# restrict the sessions to explore to a smaller subset
min_n_placecells = 16
min_n_PBEs = 27  # 27 total events ==> at least 21 events in the training set

df2_subset = df2[(df2.n_PBEs >= min_n_PBEs) & (df2.n_placecells >= min_n_placecells)]

sessions = df2_subset['time'].values.tolist()
segments = df2_subset['segment'].values.tolist()

print('Evaluating subset of {} sessions'.format(len(sessions)))

df2_subset.sort_values(by=['n_PBEs', 'n_placecells'], ascending=[False, False])


Evaluating subset of 20 sessions
Out[5]:
    animal  month  day      time  track  segment  duration  n_cells  n_placecells  n_PBEs  Notes  prescreen_z
58   gor01      6    9  22-24-40    two    short    1620.0      203            61     301    NaN          NaN
35   gor01      6    7  16-40-19    two    short    1330.0      117            46     277    NaN          NaN
36   gor01      6    7  16-40-19    two     long    1180.0      117            43     150    NaN          NaN
73   gor01      6    9   1-22-43    one     long    1012.0      203            62     117    NaN          NaN
59   gor01      6    9  22-24-40    two     long     912.0      203            66     103    NaN          NaN
72   gor01      6    9   1-22-43    one    short     617.0      203            71      91    NaN          NaN
 6   gor01      6    8  21-16-25    two     long     720.0      171            82      57    NaN          NaN
 0   vvp01      4    9  17-29-30    one    short     490.0       68            32      46    NaN          NaN
83   vvp01      4    9  16-40-54    two     long     861.0       41            25      42    NaN          NaN
 5   gor01      6    8  21-16-25    two    short     457.0      171            64      37    NaN          NaN
82   vvp01      4    9  16-40-54    two    short     485.0       41            16      37    NaN          NaN
14   gor01      6   12  16-53-46    two     long     470.0       81            36      36    NaN          NaN
28   vvp01      4   25  14-28-51    one     long     385.0       80            32      36    NaN          NaN
51   vvp01      4   16  15-12-23    one    short     347.0       74            24      34    NaN          NaN
 1   vvp01      4    9  17-29-30    one     long     800.0       68            34      33    NaN          NaN
61   gor01      6   12  15-55-31    one     long     660.0       97            40      32    NaN          NaN
75   gor01      6   13   15-22-3    two     long     530.0       82            37      31    NaN          NaN
50   gor01      6   13   14-42-6    one     long     520.0      109            36      29    NaN          NaN
41   vvp01      4   17  12-52-15    two    short     328.0       59            26      28    NaN          NaN
60   gor01      6   12  15-55-31    one    short     410.0       97            44      27    NaN          NaN
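
Sanity check for the n_PBEs threshold above: assuming a 5-fold cross-validation split (which is consistent with the "at least 21 events in the training set" comment), the smallest training set for 27 events indeed has 21 events:


In [ ]:
n_events = 27
n_folds = 5  # assumed split behind the "21 events in training set" comment

# fold sizes for an (approximately) even split of 27 events into 5 folds:
fold_sizes = [n_events // n_folds + (1 if kk < n_events % n_folds else 0)
              for kk in range(n_folds)]
print('fold sizes: {}'.format(fold_sizes))                              # [6, 6, 5, 5, 5]
print('smallest training set: {}'.format(n_events - max(fold_sizes)))   # 21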

In [52]:
# manually balance the scoring workload across the available machines
machine_assignments = [('machine 1 (AMZ)', [58]),
                       ('machine 2 (AMZ)', [35]),
                       ('machine 3 (AMZ)', [73, 59]),
                       ('machine 4 (AMZ)', [72, 0, 83, 5]),
                       ('machine 5 (AMZ)', [6, 36]),
                       ('machine 6 (AMZ)', [82, 14, 28, 51, 1, 61]),
                       ('machine 7 (ETN)', [75, 41]),
                       ('machine 8 (ALX)', [50, 60])]

for machine, session_ids in machine_assignments:
    print('{}: {}'.format(machine, df2_subset.loc[session_ids].n_PBEs.sum()))


machine 1 (AMZ): 301
machine 2 (AMZ): 277
machine 3 (AMZ): 220
machine 4 (AMZ): 216
machine 5 (AMZ): 207
machine 6 (AMZ): 208
machine 7 (ETN): 59
machine 8 (ALX): 56
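
The assignment above was balanced by hand; a greedy longest-processing-time heuristic gets close automatically. A minimal sketch (greedy_partition is an illustrative helper, not part of nelpy):


In [ ]:
def greedy_partition(loads, n_machines):
    """LPT heuristic: assign jobs largest-first, each to the currently least-loaded machine."""
    bins = [[] for _ in range(n_machines)]
    totals = np.zeros(n_machines)
    for job in sorted(loads, key=loads.get, reverse=True):
        lightest = int(np.argmin(totals))
        bins[lightest].append(job)
        totals[lightest] += loads[job]
    return bins, totals

session_loads = dict(zip(df2_subset.index, df2_subset.n_PBEs))
bins, totals = greedy_partition(session_loads, n_machines=8)
for mm, (session_ids, total) in enumerate(zip(bins, totals)):
    print('machine {}: sessions {} ({} PBEs)'.format(mm + 1, session_ids, int(total)))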

Parallel scoring

NOTE: it is relatively easy (syntax-wise) to score each session as a parallel task, but since the Bayesian scoring takes so long to compute, we get better utilization by also parallelizing over individual events, not just over sessions. This finer-grained parallelization makes the bookkeeping a little ugly, so I provide the code for both approaches here.


In [22]:
# Amazon c4.8xlarge
from nelpy.utils import PrettyDuration

n_cores = 36*7
n_shuff = 5000
n_samp  = 35000

# the 32*4*60*60/4200000 factor is an empirical timing constant for this
# hardware; scale it by the shuffle/sample workload and divide across cores:
print('It is estimated to take {} to score 1500 events using {} cores'.format(
    PrettyDuration(32*4*60*60/4200000 * n_shuff * n_samp / n_cores), n_cores))
print('It is estimated to take {} to score {} events using {} cores'.format(
    PrettyDuration(32*4*60*60/4200000 * n_shuff * n_samp / 1500), n_cores, n_cores))


It is estimated to take 21:09:50:476 hours to score 1500 events using 252 cores
It is estimated to take 3:33:20 hours to score 252 events using 252 cores

In [15]:
# Etienne's local machine
from nelpy.utils import PrettyDuration

n_cores = 8
n_shuff = 5000
n_samp  = 35000

# same extrapolation, with the empirical timing constant for this machine:
print('It is estimated to take {} to score 1500 events using {} cores'.format(
    PrettyDuration(21*4*60*60/4200000 * n_shuff * n_samp / n_cores), n_cores))
print('It is estimated to take {} to score {} events using {} cores'.format(
    PrettyDuration(21*4*60*60/4200000 * n_shuff * n_samp / 1500), n_cores, n_cores))


It is estimated to take 18 days 5:30:00 hours to score 1500 events using 8 cores
It is estimated to take 2:20:00 hours to score 8 events using 8 cores
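
The two estimates above differ only in the empirical rate constant, so the extrapolation can be wrapped in a small helper (estimate_scoring_time is an illustrative helper, not part of nelpy):


In [ ]:
def estimate_scoring_time(rate, n_events, n_cores, n_shuff=5000, n_samp=35000):
    """Extrapolate wall-clock scoring time from an empirical per-machine rate.

    `rate` is the timing constant measured for a 1500-event batch on one core.
    """
    serial_seconds = rate * n_shuff * n_samp * n_events / 1500
    return PrettyDuration(serial_seconds / n_cores)

rate_c4_8xlarge = 32*4*60*60/4200000  # Amazon c4.8xlarge
rate_local = 21*4*60*60/4200000       # Etienne's local machine

print(estimate_scoring_time(rate_c4_8xlarge, n_events=1500, n_cores=252))
print(estimate_scoring_time(rate_local, n_events=1500, n_cores=8))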

In [23]:
parallelize_by_session = False
parallelize_by_event = True

n_jobs = 8  # set this equal to the number of available cores
n_shuffles = 5000
n_samples = 35000
w = 3  # single-sided bandwidth: 0 means only the bin whose center is under the line; 3 means a total of 2*3 + 1 = 7 bins

In [21]:
if parallelize_by_session:
    from joblib import Parallel, delayed

    # The function executed for each parallel task:
    def work_sessions(arg):
        # unpack the task tuple into individual variables:
        ii, bst, tc = arg

        scores, shuffled_scores, percentiles = nel.analysis.replay.score_Davidson_final_bst_fast(bst=bst,
                                                                                            tuningcurve=tc,
                                                                                            w=w,
                                                                                            n_shuffles=n_shuffles,
                                                                                            n_samples=n_samples)

        return (ii, scores, shuffled_scores, percentiles)

    # list of task tuples to pass to work_sessions():
    parallel_sessions = [(ii, aux_data[session][segment]['PBEs'], aux_data[session][segment]['tc'])
                         for (ii, (session, segment)) in enumerate(zip(sessions, segments))]

    # anything returned by work_sessions() can be stored:
    parallel_results = Parallel(n_jobs=n_jobs, verbose=51)(map(delayed(work_sessions), parallel_sessions))

    # standardize parallel results
    idx = [result[0] for result in parallel_results]

    # check that parallel results came back IN ORDER:
    if nel.utils.is_sorted(idx):
        print('parallel results are ordered...')
    else:
        raise ValueError('results are not ordered! handle it here before proceeding...')

    scores_bayes = [result[1] for result in parallel_results]
    scores_bayes_shuffled = [result[2] for result in parallel_results]
    scores_bayes_percentile = [result[3] for result in parallel_results]

    # pack results into a nested results[session][segment] dict:
    results = dict()
    for ii, (session, segment) in enumerate(zip(sessions, segments)):
        results.setdefault(session, dict())[segment] = dict()
        results[session][segment]['scores_bayes'] = scores_bayes[ii]
        results[session][segment]['scores_bayes_shuffled'] = scores_bayes_shuffled[ii]
        results[session][segment]['scores_bayes_percentile'] = scores_bayes_percentile[ii]

    print('done packing results')

In [24]:
if parallelize_by_event:
    from joblib import Parallel, delayed

    # The function executed for each parallel task:
    def work_events(arg):
        # unpack the task tuple into individual variables:
        session, segment, ii, bst, tc = arg
        scores, shuffled_scores, percentiles = nel.analysis.replay.score_Davidson_final_bst_fast(bst=bst,
                                                                                            tuningcurve=tc,
                                                                                            w=w,
                                                                                            n_shuffles=n_shuffles,
                                                                                            n_samples=n_samples)

        return (session, segment, ii, scores, shuffled_scores, percentiles)

    # list of task tuples to pass to work_events(): unroll all events
    parallel_events = []
    for session, segment in zip(sessions, segments):
        for nn in range(aux_data[session][segment]['PBEs'].n_epochs):
            parallel_events.append((session, segment, nn,
                                    aux_data[session][segment]['PBEs'][nn],
                                    aux_data[session][segment]['tc']))

    # anything returned by work_events() can be stored:
    parallel_results = Parallel(n_jobs=n_jobs, verbose=51)(map(delayed(work_events), parallel_events))

    # standardize parallel results: compute the event-index boundaries between sessions
    bdries_ = [aux_data[session][segment]['PBEs'].n_epochs for session, segment in zip(sessions, segments)]
    bdries = np.cumsum(np.insert(bdries_, 0, 0))

    sessions_ = np.array([result[0] for result in parallel_results])
    segments_ = np.array([result[1] for result in parallel_results])
    idx = [result[2] for result in parallel_results]

    scores_bayes_evt = np.array([float(result[3]) for result in parallel_results])
    scores_bayes_shuffled_evt = np.array([result[4].squeeze() for result in parallel_results])
    scores_bayes_percentile_evt = np.array([float(result[5]) for result in parallel_results])

    # pack the per-event results back into a nested results[session][segment] dict:
    results = {}
    for nn in range(len(bdries)-1):
        session = np.unique(sessions_[bdries[nn]:bdries[nn+1]])
        if len(session) > 1:
            raise ValueError("parallel results in different format / order than expected!")
        session = session[0]
        segment = np.unique(segments_[bdries[nn]:bdries[nn+1]])
        if len(segment) > 1:
            raise ValueError("parallel results in different format / order than expected!")
        segment = segment[0]

        results.setdefault(session, dict())[segment] = dict()
        results[session][segment]['scores_bayes'] = scores_bayes_evt[bdries[nn]:bdries[nn+1]]
        results[session][segment]['scores_bayes_shuffled'] = scores_bayes_shuffled_evt[bdries[nn]:bdries[nn+1]]
        results[session][segment]['scores_bayes_percentile'] = scores_bayes_percentile_evt[bdries[nn]:bdries[nn+1]]

    print('done packing results')


[Parallel(n_jobs=8)]: Done   1 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Batch computation too fast (0.0578s.) Setting batch_size=6.
[Parallel(n_jobs=8)]: Batch computation too fast (0.1443s.) Setting batch_size=16.
...
[Parallel(n_jobs=8)]: Done 1544 out of 1544 | elapsed:    3.7s finished
done packing results
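
Before saving, it can be useful to glance at how many events in each session beat their shuffle distributions; a quick sketch (the 99th-percentile cutoff here is purely illustrative):


In [ ]:
# count events whose Bayesian replay score exceeds the 99th percentile of its shuffles
for session, segment in zip(sessions, segments):
    pct = np.asarray(results[session][segment]['scores_bayes_percentile'])
    print('{} {}: {}/{} events above the 99th shuffle percentile'.format(
        session, segment, (pct > 99).sum(), len(pct)))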

Save results to disk


In [ ]:
jar = nel.ResultsContainer(results=results, description='gor01 and vvp01 speed restricted results for best 20 candidate sessions')
jar.save_pkl('score_bayes_all_sessions.nel')
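
To reload the scored results later, the same jar pattern used for the input files above should work (assuming the saved container exposes a results attribute on load, like the exp_data/aux_data jars do):


In [ ]:
jar = nel.load_pkl('score_bayes_all_sessions.nel')
results = jar.results
del jar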