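This notebook profiles Dask when making a selection along both the first and second axes of a large multidimensional array. The use case is selecting subsets of genotype data, e.g., as required by a web-based genotype browser such as www.malariagen.net/apps/ag1000g. Summary: merely building the selection (before any data is computed) takes over ten seconds, and the profiles below show that this time is spent almost entirely in Dask's handling of the boolean selection on the first axis; computing a slice of the resulting selection is comparatively fast.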


In [1]:
import cProfile
import os

import zarr; print('zarr', zarr.__version__)
import dask; print('dask', dask.__version__)
import dask.array as da
import numpy as np


zarr 2.1.1
dask 0.11.0

Real data


In [2]:
# here's the real data
# The store location can be overridden with the AG1000G_ZARR_PATH environment
# variable so the notebook can be re-run on a machine other than the original
# host; the default is the path used when these outputs were produced.
AG1000G_ZARR_PATH = os.environ.get(
    'AG1000G_ZARR_PATH',
    '/kwiat/2/coluzzi/ag1000g/data/phase1/release/AR3.1/variation/main/zarr2/zstd/ag1000g.phase1.ar3',
)
# open read-only: this notebook only profiles selections, it never writes
callset = zarr.open_group(AG1000G_ZARR_PATH, mode='r')
callset


Out[2]:
Group(/, 8)
  arrays: 1; samples
  groups: 7; 2L, 2R, 3L, 3R, UNKN, X, Y_unplaced
  store: DirectoryStore

In [3]:
# the array we're going to work with: the genotype calls stored under the
# '3R' group of the callset (a zarr array; displaying it shows its summary)
genotype_key = '3R/calldata/genotype'
g = callset[genotype_key]
g


Out[3]:
Array(/3R/calldata/genotype, (22632425, 765, 2), int8, chunks=(13107, 40, 2), order=C)
  nbytes: 32.2G; nbytes_stored: 1.0G; ratio: 31.8; initialized: 34540/34540
  compressor: Blosc(cname='zstd', clevel=1, shuffle=2)
  store: DirectoryStore

In [4]:
# wrap as dask array with very simple chunking of first dim only:
# one dask chunk per zarr chunk along axis 0 (g.chunks[0] rows), while None
# keeps axes 1 and 2 whole (cf. chunksize=(13107, 765, 2) in the output)
%time gd = da.from_array(g, chunks=(g.chunks[0], None, None))
gd


CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 5.13 ms
Out[4]:
dask.array<array-b..., shape=(22632425, 765, 2), dtype=int8, chunksize=(13107, 765, 2)>

In [5]:
# Read the per-row boolean flags used to select along the first axis fully
# into memory, then show its shape, dtype and how many rows it keeps.
dim0_condition = callset['3R/variants/FILTER_PASS'][:]
(
    dim0_condition.shape,
    dim0_condition.dtype,
    np.count_nonzero(dim0_condition),  # number of True entries = rows selected
)


Out[5]:
((22632425,), dtype('bool'), 13167162)

In [6]:
# invent a random (but reproducible) selection of 100 distinct, sorted
# indices on the second axis; the size of that axis is taken from the array
# itself rather than hard-coded (g.shape[1] == 765 for this dataset)
rng_dim1 = np.random.RandomState(0)
dim1_indices = sorted(rng_dim1.choice(g.shape[1], size=100, replace=False))

In [7]:
# setup the 2D selection - this is the slow bit
# This only builds the lazy dask graph (boolean mask on axis 0, then the chosen
# indices on axis 1); no chunk data is computed here. The time goes into dask's
# slicing code, as the cProfile runs further down show.
%time gd_sel = gd[dim0_condition][:, dim1_indices]
gd_sel


CPU times: user 15.3 s, sys: 256 ms, total: 15.5 s
Wall time: 15.5 s
Out[7]:
dask.array<getitem..., shape=(13167162, 100, 2), dtype=int8, chunksize=(8873, 100, 2)>

In [23]:
# now load a slice from this new selection - quick!
# NOTE(review): this cell ran out of order (In [23]) relative to its
# neighbours; re-run the notebook top-to-bottom to confirm the timing.
# optimize_graph=False skips dask's graph optimisation for this compute,
# presumably because optimising the very large graph is itself costly - TODO confirm.
%time gd_sel[1000000:1100000].compute(optimize_graph=False)


CPU times: user 1.21 s, sys: 152 ms, total: 1.36 s
Wall time: 316 ms
Out[23]:
array([[[0, 0],
        [0, 0],
        [0, 0],
        ..., 
        [0, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0],
        ..., 
        [0, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0],
        ..., 
        [0, 0],
        [0, 0],
        [0, 0]],

       ..., 
       [[0, 0],
        [0, 0],
        [0, 0],
        ..., 
        [0, 1],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0],
        ..., 
        [0, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [0, 0],
        ..., 
        [0, 0],
        [0, 0],
        [0, 0]]], dtype=int8)

In [9]:
# what's taking so long?
# Profile only the graph-construction step (nothing is computed), with the
# report ordered by internal time so the hottest dask slicing helpers come first.
cProfile.run('gd[dim0_condition][:, dim1_indices]', sort='time')


         105406881 function calls (79072145 primitive calls) in 26.182 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
13167268/6    6.807    0.000    9.038    1.506 slicing.py:623(check_index)
        2    4.713    2.356    5.831    2.916 slicing.py:398(partition_by_size)
13167270/2    4.470    0.000    8.763    4.382 slicing.py:540(posify_index)
 52669338    4.118    0.000    4.119    0.000 {built-in method builtins.isinstance}
        2    2.406    1.203    8.763    4.382 slicing.py:563(<listcomp>)
        1    0.875    0.875    0.875    0.875 slicing.py:44(<listcomp>)
 13182474    0.600    0.000    0.600    0.000 {built-in method builtins.len}
        2    0.527    0.264    0.527    0.264 slicing.py:420(issorted)
 13189168    0.520    0.000    0.520    0.000 {method 'append' of 'list' objects}
        2    0.271    0.136    0.271    0.136 slicing.py:479(<listcomp>)
        2    0.220    0.110    0.220    0.110 {built-in method builtins.sorted}
        1    0.162    0.162    0.162    0.162 {method 'tolist' of 'numpy.ndarray' objects}
        2    0.113    0.056   26.071   13.035 core.py:1024(__getitem__)
        2    0.112    0.056    6.435    3.217 slicing.py:441(take_sorted)
        1    0.111    0.111   26.182   26.182 <string>:1(<module>)
        2    0.060    0.030   24.843   12.422 slicing.py:142(slice_with_newaxes)
    106/3    0.039    0.000    1.077    0.359 slicing.py:15(sanitize_index)
        3    0.037    0.012    0.037    0.012 {built-in method _hashlib.openssl_md5}
     6726    0.012    0.000    0.017    0.000 slicing.py:567(insert_many)
     3364    0.004    0.000    0.021    0.000 slicing.py:156(<genexpr>)
    20178    0.003    0.000    0.003    0.000 {method 'pop' of 'list' objects}
        8    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
        2    0.000    0.000   25.920   12.960 slicing.py:60(slice_array)
        2    0.000    0.000    0.000    0.000 slicing.py:162(<genexpr>)
        2    0.000    0.000    0.000    0.000 slicing.py:464(<listcomp>)
    106/4    0.000    0.000    0.037    0.009 utils.py:502(__call__)
      100    0.000    0.000    0.000    0.000 arrayprint.py:340(array2string)
        2    0.000    0.000    0.037    0.019 base.py:343(tokenize)
      100    0.000    0.000    0.000    0.000 {built-in method builtins.repr}
        2    0.000    0.000   24.763   12.381 slicing.py:170(slice_wrap_lists)
      108    0.000    0.000    0.000    0.000 abc.py:178(__instancecheck__)
        2    0.000    0.000    6.962    3.481 slicing.py:487(take)
        1    0.000    0.000   26.182   26.182 {built-in method builtins.exec}
        2    0.000    0.000    0.000    0.000 slicing.py:465(<listcomp>)
        1    0.000    0.000    0.037    0.037 base.py:314(normalize_array)
      2/1    0.000    0.000    0.000    0.000 base.py:270(normalize_seq)
      116    0.000    0.000    0.000    0.000 _weakrefset.py:70(__contains__)
      100    0.000    0.000    0.000    0.000 numeric.py:1835(array_str)
        1    0.000    0.000    0.000    0.000 slicing.py:47(<listcomp>)
        6    0.000    0.000    0.000    0.000 {built-in method builtins.sum}
        2    0.000    0.000    0.000    0.000 exceptions.py:15(merge)
      100    0.000    0.000    0.000    0.000 inspect.py:441(getmro)
        2    0.000    0.000    0.000    0.000 slicing.py:475(<listcomp>)
        4    0.000    0.000    0.000    0.000 dicttoolz.py:19(merge)
        4    0.000    0.000    0.000    0.000 functoolz.py:217(__call__)
        2    0.000    0.000    0.000    0.000 core.py:1455(normalize_chunks)
        4    0.000    0.000    0.000    0.000 dicttoolz.py:11(_get_factory)
        2    0.000    0.000    0.000    0.000 slicing.py:467(<listcomp>)
      100    0.000    0.000    0.000    0.000 {method 'item' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.000    0.000 core.py:794(__init__)
        8    0.000    0.000    0.000    0.000 {built-in method builtins.all}
        8    0.000    0.000    0.000    0.000 slicing.py:197(<genexpr>)
        8    0.000    0.000    0.000    0.000 slicing.py:183(<genexpr>)
        5    0.000    0.000    0.000    0.000 core.py:1043(<genexpr>)
        7    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        5    0.000    0.000    0.000    0.000 slicing.py:125(<genexpr>)
        1    0.000    0.000    0.000    0.000 {method 'view' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.000    0.000 slicing.py:192(<listcomp>)
        3    0.000    0.000    0.000    0.000 {method 'hexdigest' of '_hashlib.HASH' objects}
        2    0.000    0.000    0.000    0.000 slicing.py:606(replace_ellipsis)
        2    0.000    0.000    0.000    0.000 slicing.py:613(<listcomp>)
        1    0.000    0.000    0.000    0.000 {method 'ravel' of 'numpy.ndarray' objects}
        4    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
        2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
        8    0.000    0.000    0.000    0.000 slicing.py:207(<genexpr>)
        2    0.000    0.000    0.000    0.000 core.py:826(_get_chunks)
        2    0.000    0.000    0.000    0.000 core.py:1452(<lambda>)
        2    0.000    0.000    0.000    0.000 slicing.py:149(<listcomp>)
        2    0.000    0.000    0.000    0.000 slicing.py:150(<listcomp>)
        1    0.000    0.000    0.000    0.000 functoolz.py:11(identity)
        4    0.000    0.000    0.000    0.000 {method 'pop' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        2    0.000    0.000    0.000    0.000 {method 'count' of 'tuple' objects}



In [10]:
# Same statement as the previous profile, reported by cumulative time so the
# call chain from core.py __getitem__ down into slicing.py is visible.
cProfile.run(
    'gd[dim0_condition][:, dim1_indices]',
    sort='cumtime',
)


         105406881 function calls (79072145 primitive calls) in 25.630 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   25.630   25.630 {built-in method builtins.exec}
        1    0.107    0.107   25.630   25.630 <string>:1(<module>)
        2    0.102    0.051   25.523   12.761 core.py:1024(__getitem__)
        2    0.001    0.000   25.381   12.691 slicing.py:60(slice_array)
        2    0.049    0.024   24.214   12.107 slicing.py:142(slice_with_newaxes)
        2    0.000    0.000   24.147   12.073 slicing.py:170(slice_wrap_lists)
13167268/6    6.664    0.000    8.855    1.476 slicing.py:623(check_index)
13167270/2    4.354    0.000    8.466    4.233 slicing.py:540(posify_index)
        2    2.277    1.139    8.465    4.233 slicing.py:563(<listcomp>)
        2    0.000    0.000    6.826    3.413 slicing.py:487(take)
        2    0.111    0.056    6.331    3.165 slicing.py:441(take_sorted)
        2    4.628    2.314    5.704    2.852 slicing.py:398(partition_by_size)
 52669338    4.026    0.000    4.026    0.000 {built-in method builtins.isinstance}
    106/3    0.071    0.001    1.167    0.389 slicing.py:15(sanitize_index)
        1    0.943    0.943    0.943    0.943 slicing.py:44(<listcomp>)
 13182474    0.581    0.000    0.581    0.000 {built-in method builtins.len}
 13189168    0.497    0.000    0.497    0.000 {method 'append' of 'list' objects}
        2    0.495    0.248    0.495    0.248 slicing.py:420(issorted)
        2    0.281    0.141    0.281    0.141 slicing.py:479(<listcomp>)
        2    0.234    0.117    0.234    0.117 {built-in method builtins.sorted}
        1    0.152    0.152    0.152    0.152 {method 'tolist' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.039    0.020 base.py:343(tokenize)
    106/4    0.000    0.000    0.039    0.010 utils.py:502(__call__)
        1    0.000    0.000    0.039    0.039 base.py:314(normalize_array)
        3    0.039    0.013    0.039    0.013 {built-in method _hashlib.openssl_md5}
     3364    0.003    0.000    0.019    0.000 slicing.py:156(<genexpr>)
     6726    0.012    0.000    0.016    0.000 slicing.py:567(insert_many)
    20178    0.003    0.000    0.003    0.000 {method 'pop' of 'list' objects}
        4    0.000    0.000    0.000    0.000 dicttoolz.py:19(merge)
        8    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
        4    0.000    0.000    0.000    0.000 functoolz.py:217(__call__)
        2    0.000    0.000    0.000    0.000 exceptions.py:15(merge)
      2/1    0.000    0.000    0.000    0.000 base.py:270(normalize_seq)
        2    0.000    0.000    0.000    0.000 slicing.py:162(<genexpr>)
      100    0.000    0.000    0.000    0.000 {built-in method builtins.repr}
        1    0.000    0.000    0.000    0.000 slicing.py:47(<listcomp>)
        2    0.000    0.000    0.000    0.000 slicing.py:464(<listcomp>)
      100    0.000    0.000    0.000    0.000 numeric.py:1835(array_str)
      100    0.000    0.000    0.000    0.000 arrayprint.py:340(array2string)
      108    0.000    0.000    0.000    0.000 abc.py:178(__instancecheck__)
        2    0.000    0.000    0.000    0.000 slicing.py:465(<listcomp>)
        8    0.000    0.000    0.000    0.000 {built-in method builtins.all}
        2    0.000    0.000    0.000    0.000 core.py:794(__init__)
      116    0.000    0.000    0.000    0.000 _weakrefset.py:70(__contains__)
        2    0.000    0.000    0.000    0.000 core.py:1455(normalize_chunks)
        6    0.000    0.000    0.000    0.000 {built-in method builtins.sum}
        8    0.000    0.000    0.000    0.000 slicing.py:183(<genexpr>)
      100    0.000    0.000    0.000    0.000 {method 'item' of 'numpy.ndarray' objects}
      100    0.000    0.000    0.000    0.000 inspect.py:441(getmro)
        2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
        2    0.000    0.000    0.000    0.000 slicing.py:606(replace_ellipsis)
        2    0.000    0.000    0.000    0.000 slicing.py:475(<listcomp>)
        5    0.000    0.000    0.000    0.000 slicing.py:125(<genexpr>)
        2    0.000    0.000    0.000    0.000 slicing.py:467(<listcomp>)
        3    0.000    0.000    0.000    0.000 {method 'hexdigest' of '_hashlib.HASH' objects}
        1    0.000    0.000    0.000    0.000 {method 'view' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.000    0.000 slicing.py:192(<listcomp>)
        4    0.000    0.000    0.000    0.000 dicttoolz.py:11(_get_factory)
        5    0.000    0.000    0.000    0.000 core.py:1043(<genexpr>)
        7    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        8    0.000    0.000    0.000    0.000 slicing.py:207(<genexpr>)
        2    0.000    0.000    0.000    0.000 slicing.py:613(<listcomp>)
        2    0.000    0.000    0.000    0.000 slicing.py:149(<listcomp>)
        1    0.000    0.000    0.000    0.000 {method 'ravel' of 'numpy.ndarray' objects}
        8    0.000    0.000    0.000    0.000 slicing.py:197(<genexpr>)
        2    0.000    0.000    0.000    0.000 core.py:826(_get_chunks)
        2    0.000    0.000    0.000    0.000 core.py:1452(<lambda>)
        4    0.000    0.000    0.000    0.000 {method 'pop' of 'dict' objects}
        4    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
        2    0.000    0.000    0.000    0.000 slicing.py:150(<listcomp>)
        2    0.000    0.000    0.000    0.000 {method 'count' of 'tuple' objects}
        1    0.000    0.000    0.000    0.000 functoolz.py:11(identity)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


Synthetic data


In [22]:
# create a synthetic dataset for profiling
# The array is filled one chunk-row (10000 rows) at a time, so the full
# ~7.5 GB of int8 data is never materialised in memory at once (building it
# with np.random.randint first and passing it to zarr.array would need the
# whole array in RAM). A fixed seed makes the dataset reproducible.
rng_synth = np.random.RandomState(42)
synth_shape = (20000000, 200, 2)
synth_chunks = (10000, 100, 2)
a = zarr.zeros(synth_shape, chunks=synth_chunks, dtype='i1',
               compressor=zarr.Blosc(cname='zstd', clevel=1, shuffle=2))
for start in range(0, synth_shape[0], synth_chunks[0]):
    stop = min(start + synth_chunks[0], synth_shape[0])
    # values drawn uniformly from -1..3 inclusive, as in the original generator
    a[start:stop] = rng_synth.randint(-1, 4, size=(stop - start,) + synth_shape[1:],
                                      dtype='i1')
a


Out[22]:
Array((20000000, 200, 2), int8, chunks=(10000, 100, 2), order=C)
  nbytes: 7.5G; nbytes_stored: 2.7G; ratio: 2.8; initialized: 4000/4000
  compressor: Blosc(cname='zstd', clevel=1, shuffle=2)
  store: dict

In [24]:
# create a synthetic selection for first axis: a boolean mask with one entry
# per row of `a`, each True/False with equal probability (about half the rows
# are kept). Seeded so the selection, and hence the timings, are reproducible.
rng_dim0_synth = np.random.RandomState(1)
c = rng_dim0_synth.randint(0, 2, size=a.shape[0], dtype=bool)

In [25]:
# create a synthetic selection for second axis: 100 distinct, sorted indices
# drawn from range(a.shape[1]); seeded for reproducibility
rng_dim1_synth = np.random.RandomState(2)
s = sorted(rng_dim1_synth.choice(a.shape[1], size=100, replace=False))

In [26]:
# wrap the synthetic zarr array as a dask array chunked along the first axis
# only: one dask chunk per zarr chunk-row, with None keeping axes 1 and 2 whole
%time d = da.from_array(a, chunks=(a.chunks[0], None, None))
d


CPU times: user 208 ms, sys: 0 ns, total: 208 ms
Wall time: 206 ms
Out[26]:
dask.array<array-5..., shape=(20000000, 200, 2), dtype=int8, chunksize=(10000, 200, 2)>

In [27]:
# same two-step selection as on the real data (boolean mask on axis 0, then an
# index list on axis 1); only the lazy dask graph is built here
%time ds = d[c][:, s]


CPU times: user 12 s, sys: 200 ms, total: 12.2 s
Wall time: 12.2 s

In [28]:
# Profile the full two-step selection on the synthetic data, reported by
# internal time; the hot spots match those seen on the real data.
cProfile.run(
    'd[c][:, s]',
    sort='time',
)


         80095589 function calls (60091843 primitive calls) in 19.467 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
10001773/6    4.872    0.000    6.456    1.076 slicing.py:623(check_index)
        2    3.517    1.758    4.357    2.179 slicing.py:398(partition_by_size)
10001775/2    3.354    0.000    6.484    3.242 slicing.py:540(posify_index)
 40007358    2.965    0.000    2.965    0.000 {built-in method builtins.isinstance}
        2    1.749    0.875    6.484    3.242 slicing.py:563(<listcomp>)
        1    0.878    0.878    0.878    0.878 slicing.py:44(<listcomp>)
 10019804    0.451    0.000    0.451    0.000 {built-in method builtins.len}
 10027774    0.392    0.000    0.392    0.000 {method 'append' of 'list' objects}
        2    0.363    0.181    0.363    0.181 slicing.py:420(issorted)
        2    0.270    0.135    4.786    2.393 slicing.py:441(take_sorted)
        1    0.207    0.207    0.207    0.207 {method 'tolist' of 'numpy.ndarray' objects}
        2    0.158    0.079    0.158    0.079 {built-in method builtins.sorted}
        1    0.094    0.094   19.467   19.467 <string>:1(<module>)
        2    0.079    0.040   19.373    9.686 core.py:1024(__getitem__)
        2    0.035    0.017   18.147    9.074 slicing.py:142(slice_with_newaxes)
        3    0.033    0.011    0.033    0.011 {built-in method _hashlib.openssl_md5}
    106/3    0.028    0.000    1.112    0.371 slicing.py:15(sanitize_index)
     8002    0.015    0.000    0.020    0.000 slicing.py:567(insert_many)
     4002    0.004    0.000    0.023    0.000 slicing.py:156(<genexpr>)
    24006    0.003    0.000    0.003    0.000 {method 'pop' of 'list' objects}
        8    0.001    0.000    0.001    0.000 {method 'update' of 'dict' objects}
        2    0.001    0.000    0.001    0.000 slicing.py:479(<listcomp>)
        2    0.000    0.000   19.259    9.630 slicing.py:60(slice_array)
        2    0.000    0.000    0.000    0.000 slicing.py:162(<genexpr>)
        2    0.000    0.000    0.000    0.000 slicing.py:464(<listcomp>)
        2    0.000    0.000    0.000    0.000 slicing.py:465(<listcomp>)
    106/4    0.000    0.000    0.034    0.008 utils.py:502(__call__)
        2    0.000    0.000   18.089    9.044 slicing.py:170(slice_wrap_lists)
      100    0.000    0.000    0.000    0.000 arrayprint.py:340(array2string)
      100    0.000    0.000    0.000    0.000 {built-in method builtins.repr}
      108    0.000    0.000    0.000    0.000 abc.py:178(__instancecheck__)
        2    0.000    0.000    5.149    2.574 slicing.py:487(take)
        2    0.000    0.000    0.034    0.017 base.py:343(tokenize)
        1    0.000    0.000    0.033    0.033 base.py:314(normalize_array)
      116    0.000    0.000    0.000    0.000 _weakrefset.py:70(__contains__)
      2/1    0.000    0.000    0.000    0.000 base.py:270(normalize_seq)
        6    0.000    0.000    0.000    0.000 {built-in method builtins.sum}
      100    0.000    0.000    0.000    0.000 numeric.py:1835(array_str)
        1    0.000    0.000    0.000    0.000 slicing.py:47(<listcomp>)
        1    0.000    0.000   19.467   19.467 {built-in method builtins.exec}
      100    0.000    0.000    0.000    0.000 inspect.py:441(getmro)
        8    0.000    0.000    0.000    0.000 {built-in method builtins.all}
        4    0.000    0.000    0.001    0.000 dicttoolz.py:19(merge)
        2    0.000    0.000    0.000    0.000 core.py:1455(normalize_chunks)
      100    0.000    0.000    0.000    0.000 {method 'item' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.000    0.000 slicing.py:475(<listcomp>)
        2    0.000    0.000    0.000    0.000 core.py:794(__init__)
        2    0.000    0.000    0.000    0.000 slicing.py:467(<listcomp>)
        3    0.000    0.000    0.000    0.000 {method 'hexdigest' of '_hashlib.HASH' objects}
        2    0.000    0.000    0.001    0.000 exceptions.py:15(merge)
        7    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        2    0.000    0.000    0.000    0.000 slicing.py:606(replace_ellipsis)
        4    0.000    0.000    0.001    0.000 functoolz.py:217(__call__)
        8    0.000    0.000    0.000    0.000 slicing.py:183(<genexpr>)
        4    0.000    0.000    0.000    0.000 dicttoolz.py:11(_get_factory)
        5    0.000    0.000    0.000    0.000 core.py:1043(<genexpr>)
        2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'view' of 'numpy.ndarray' objects}
        8    0.000    0.000    0.000    0.000 slicing.py:197(<genexpr>)
        5    0.000    0.000    0.000    0.000 slicing.py:125(<genexpr>)
        2    0.000    0.000    0.000    0.000 slicing.py:192(<listcomp>)
        8    0.000    0.000    0.000    0.000 slicing.py:207(<genexpr>)
        2    0.000    0.000    0.000    0.000 slicing.py:613(<listcomp>)
        2    0.000    0.000    0.000    0.000 {method 'count' of 'tuple' objects}
        1    0.000    0.000    0.000    0.000 {method 'ravel' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 functoolz.py:11(identity)
        4    0.000    0.000    0.000    0.000 {method 'pop' of 'dict' objects}
        2    0.000    0.000    0.000    0.000 slicing.py:150(<listcomp>)
        2    0.000    0.000    0.000    0.000 core.py:826(_get_chunks)
        2    0.000    0.000    0.000    0.000 core.py:1452(<lambda>)
        2    0.000    0.000    0.000    0.000 slicing.py:149(<listcomp>)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        4    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}



In [29]:
# load a 100,000-row slice of the synthetic selection; as with the real data,
# computing it is fast compared with building the selection itself
%time ds[1000000:1100000].compute(optimize_graph=False)


CPU times: user 452 ms, sys: 8 ms, total: 460 ms
Wall time: 148 ms
Out[29]:
array([[[ 2, -1],
        [ 2,  3],
        [ 3,  0],
        ..., 
        [ 1,  3],
        [-1, -1],
        [ 1,  1]],

       [[ 1, -1],
        [ 2,  2],
        [-1,  2],
        ..., 
        [ 2, -1],
        [ 1,  3],
        [-1, -1]],

       [[ 1, -1],
        [ 2,  0],
        [ 0,  3],
        ..., 
        [ 2,  2],
        [ 3,  2],
        [ 0,  2]],

       ..., 
       [[ 1,  2],
        [ 3, -1],
        [ 2,  1],
        ..., 
        [ 1,  2],
        [ 1,  0],
        [ 2,  0]],

       [[ 1,  2],
        [ 1,  0],
        [ 2,  3],
        ..., 
        [-1,  2],
        [ 3,  3],
        [ 1, -1]],

       [[-1,  3],
        [ 2,  2],
        [ 1,  1],
        ..., 
        [ 3,  3],
        [ 0,  0],
        [ 0,  2]]], dtype=int8)

In [30]:
# The problem is in fact just the dim0 selection: profiling the boolean-mask
# step on its own takes essentially as long as the full d[c][:, s] above.
cProfile.run(
    'd[c]',
    sort='time',
)


         80055494 function calls (60052157 primitive calls) in 19.425 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
10001670/3    5.032    0.000    6.671    2.224 slicing.py:623(check_index)
        1    3.459    3.459    4.272    4.272 slicing.py:398(partition_by_size)
10001671/1    3.287    0.000    6.378    6.378 slicing.py:540(posify_index)
 40006704    2.999    0.000    2.999    0.000 {built-in method builtins.isinstance}
        1    1.731    1.731    6.378    6.378 slicing.py:563(<listcomp>)
        1    0.849    0.849    0.849    0.849 slicing.py:44(<listcomp>)
 10011685    0.433    0.000    0.433    0.000 {built-in method builtins.len}
 10015670    0.381    0.000    0.381    0.000 {method 'append' of 'list' objects}
        1    0.355    0.355    0.355    0.355 slicing.py:420(issorted)
        1    0.196    0.196    0.196    0.196 {method 'tolist' of 'numpy.ndarray' objects}
        1    0.193    0.193    0.193    0.193 slicing.py:479(<listcomp>)
        1    0.157    0.157    0.157    0.157 {built-in method builtins.sorted}
        1    0.085    0.085    4.707    4.707 slicing.py:441(take_sorted)
        1    0.085    0.085   19.425   19.425 <string>:1(<module>)
        1    0.079    0.079   19.341   19.341 core.py:1024(__getitem__)
        1    0.034    0.034   18.157   18.157 slicing.py:142(slice_with_newaxes)
        2    0.033    0.017    0.033    0.017 {built-in method _hashlib.openssl_md5}
        1    0.026    0.026    1.071    1.071 slicing.py:15(sanitize_index)
     4001    0.007    0.000    0.009    0.000 slicing.py:567(insert_many)
     2001    0.002    0.000    0.011    0.000 slicing.py:156(<genexpr>)
    12003    0.001    0.000    0.001    0.000 {method 'pop' of 'list' objects}
        4    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
        1    0.000    0.000   19.228   19.228 slicing.py:60(slice_array)
        1    0.000    0.000    0.000    0.000 slicing.py:464(<listcomp>)
        1    0.000    0.000    0.000    0.000 slicing.py:162(<genexpr>)
        1    0.000    0.000    0.033    0.033 base.py:314(normalize_array)
        1    0.000    0.000   18.111   18.111 slicing.py:170(slice_wrap_lists)
        1    0.000    0.000    0.000    0.000 slicing.py:465(<listcomp>)
        1    0.000    0.000    5.062    5.062 slicing.py:487(take)
        1    0.000    0.000    0.033    0.033 base.py:343(tokenize)
        1    0.000    0.000   19.425   19.425 {built-in method builtins.exec}
        2    0.000    0.000    0.000    0.000 functoolz.py:217(__call__)
        3    0.000    0.000    0.000    0.000 {built-in method builtins.sum}
        2    0.000    0.000    0.000    0.000 abc.py:178(__instancecheck__)
        1    0.000    0.000    0.000    0.000 core.py:1455(normalize_chunks)
        2    0.000    0.000    0.000    0.000 dicttoolz.py:19(merge)
        4    0.000    0.000    0.000    0.000 _weakrefset.py:70(__contains__)
        2    0.000    0.000    0.000    0.000 dicttoolz.py:11(_get_factory)
        1    0.000    0.000    0.000    0.000 exceptions.py:15(merge)
        1    0.000    0.000    0.000    0.000 core.py:794(__init__)
        4    0.000    0.000    0.000    0.000 {built-in method builtins.all}
        1    0.000    0.000    0.000    0.000 slicing.py:467(<listcomp>)
        1    0.000    0.000    0.000    0.000 {method 'view' of 'numpy.ndarray' objects}
        4    0.000    0.000    0.000    0.000 slicing.py:183(<genexpr>)
        2    0.000    0.000    0.000    0.000 {method 'hexdigest' of '_hashlib.HASH' objects}
        1    0.000    0.000    0.000    0.000 slicing.py:606(replace_ellipsis)
        1    0.000    0.000    0.000    0.000 slicing.py:192(<listcomp>)
        4    0.000    0.000    0.000    0.000 slicing.py:207(<genexpr>)
        1    0.000    0.000    0.000    0.000 slicing.py:475(<listcomp>)
        2    0.000    0.000    0.033    0.017 utils.py:502(__call__)
        2    0.000    0.000    0.000    0.000 slicing.py:125(<genexpr>)
        2    0.000    0.000    0.000    0.000 core.py:1043(<genexpr>)
        4    0.000    0.000    0.000    0.000 slicing.py:197(<genexpr>)
        1    0.000    0.000    0.000    0.000 core.py:826(_get_chunks)
        2    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        1    0.000    0.000    0.000    0.000 {method 'ravel' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
        1    0.000    0.000    0.000    0.000 slicing.py:613(<listcomp>)
        1    0.000    0.000    0.000    0.000 core.py:1452(<lambda>)
        1    0.000    0.000    0.000    0.000 slicing.py:149(<listcomp>)
        2    0.000    0.000    0.000    0.000 {method 'pop' of 'dict' objects}
        2    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 slicing.py:150(<listcomp>)
        1    0.000    0.000    0.000    0.000 {method 'count' of 'tuple' objects}