In [2]:
import numpy as np

In [3]:
X = np.random.random(size=(80, 100)) > 0.3

In [4]:
XX = [np.nonzero(row)[0] for row in X]
XX


Out[4]:
[array([ 1,  2,  4,  5,  7,  8,  9, 10, 12, 13, 16, 17, 18, 19, 20, 21, 22,
        26, 27, 28, 31, 32, 33, 36, 37, 38, 39, 40, 41, 42, 43, 47, 49, 50,
        51, 52, 53, 56, 57, 58, 59, 60, 61, 63, 65, 66, 68, 69, 70, 71, 72,
        73, 74, 76, 77, 78, 81, 82, 84, 85, 86, 88, 89, 90, 91, 92, 93, 94,
        95, 96, 97, 98, 99]),
 array([ 0,  1,  2,  4,  5,  6,  8,  9, 11, 12, 13, 14, 16, 17, 18, 19, 20,
        23, 24, 25, 26, 27, 28, 29, 30, 33, 34, 35, 36, 38, 39, 40, 41, 43,
        45, 46, 47, 48, 49, 52, 53, 55, 56, 57, 58, 59, 60, 61, 63, 65, 66,
        68, 69, 70, 71, 73, 74, 76, 79, 80, 82, 83, 84, 85, 87, 89, 90, 92,
        94, 96, 97, 98, 99]),
 array([ 0,  2,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 17, 19, 20, 21,
        23, 24, 25, 27, 29, 30, 31, 33, 35, 36, 37, 38, 39, 41, 43, 44, 45,
        46, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        65, 67, 68, 69, 72, 73, 75, 76, 78, 79, 80, 81, 82, 83, 84, 85, 86,
        89, 90, 91, 93, 94, 95, 96, 97, 99]),
 array([ 0,  1,  2,  4,  5,  6,  7, 11, 12, 13, 14, 15, 16, 18, 19, 20, 23,
        24, 25, 26, 27, 28, 30, 31, 32, 34, 36, 37, 38, 39, 40, 41, 44, 45,
        46, 51, 52, 53, 55, 56, 57, 59, 61, 62, 63, 64, 65, 66, 67, 71, 72,
        73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
        91, 93, 94, 95, 96, 97, 98, 99]),
 array([ 1,  3,  5,  6,  8,  9, 10, 11, 12, 16, 18, 19, 23, 26, 28, 29, 30,
        31, 35, 36, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52, 53, 54,
        56, 57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 69, 70, 71, 72, 73, 75,
        76, 77, 78, 79, 80, 81, 82, 83, 85, 87, 88, 89, 92, 93, 94, 96, 97,
        98]),
 array([ 0,  1,  3,  4,  6,  7,  9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20,
        22, 23, 24, 26, 28, 29, 30, 32, 33, 34, 35, 37, 38, 40, 43, 45, 46,
        48, 49, 50, 51, 53, 54, 55, 57, 59, 60, 63, 64, 65, 68, 69, 71, 73,
        76, 77, 78, 79, 80, 82, 83, 84, 86, 88, 89, 91, 92, 93, 94, 95, 96,
        99]),
 array([ 0,  1,  2,  3,  7,  9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 23,
        27, 29, 30, 32, 33, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 53, 55, 56, 57, 58, 59, 61, 62, 64, 65, 66, 67, 68, 69,
        70, 72, 73, 75, 76, 77, 78, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
        91, 92, 93, 95, 96, 97, 98, 99]),
 array([ 1,  2,  5,  8,  9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23,
        24, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
        44, 47, 48, 50, 51, 52, 53, 54, 56, 57, 61, 62, 63, 64, 65, 66, 67,
        68, 69, 72, 73, 74, 76, 78, 79, 80, 81, 82, 84, 85, 86, 88, 91, 92,
        93, 94, 95, 97, 98]),
 array([ 0,  1,  4,  5,  8,  9, 10, 12, 14, 15, 19, 20, 21, 22, 23, 25, 26,
        27, 28, 29, 31, 33, 34, 35, 36, 39, 40, 41, 42, 43, 44, 45, 46, 51,
        53, 54, 59, 60, 63, 64, 65, 66, 68, 69, 70, 71, 72, 73, 74, 75, 77,
        78, 79, 80, 81, 82, 83, 84, 86, 89, 90, 92, 93, 94, 96, 97, 98]),
 array([ 0,  2,  3,  4,  5,  8,  9, 11, 13, 14, 15, 16, 18, 19, 20, 22, 23,
        25, 27, 28, 29, 30, 31, 33, 34, 35, 38, 40, 41, 43, 44, 45, 46, 47,
        48, 53, 54, 55, 59, 60, 61, 62, 64, 65, 66, 67, 68, 70, 72, 73, 74,
        75, 76, 78, 79, 80, 81, 82, 83, 85, 86, 87, 88, 91, 93, 95, 97, 99]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 47, 49, 51, 52, 53, 55, 56, 57,
        58, 60, 61, 62, 63, 64, 65, 67, 68, 69, 72, 73, 74, 75, 79, 80, 81,
        82, 83, 84, 89, 90, 92, 94, 96]),
 array([ 0,  2,  3,  5,  6,  8, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22,
        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 40, 41, 42,
        43, 44, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62,
        63, 64, 65, 66, 67, 68, 69, 70, 72, 73, 75, 76, 78, 82, 84, 86, 92,
        93, 94, 95, 96, 98, 99]),
 array([ 0,  2,  3,  4,  5,  7,  8,  9, 10, 11, 12, 13, 15, 16, 18, 19, 21,
        22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 39, 41,
        42, 43, 45, 46, 47, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 63, 64,
        65, 66, 67, 68, 71, 72, 74, 75, 78, 85, 86, 87, 88, 91, 92, 93, 94,
        95, 96, 97, 99]),
 array([ 0,  2,  3,  4,  5,  6,  7,  9, 10, 11, 13, 15, 16, 17, 18, 19, 21,
        23, 26, 27, 28, 30, 31, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 45,
        47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 60, 61, 62, 63, 65,
        66, 67, 68, 69, 70, 71, 74, 75, 76, 78, 80, 81, 82, 85, 88, 89, 92,
        93, 94, 95, 96, 98, 99]),
 array([ 1,  3,  4,  5,  6,  7,  8, 10, 11, 13, 15, 16, 18, 19, 20, 21, 22,
        23, 26, 27, 29, 31, 32, 34, 35, 37, 38, 44, 45, 47, 48, 49, 50, 51,
        53, 54, 55, 57, 58, 59, 60, 61, 64, 65, 66, 67, 68, 70, 71, 73, 74,
        75, 77, 79, 80, 83, 84, 86, 87, 88, 92, 95, 97, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  7,  9, 11, 13, 14, 15, 17, 18, 20, 22, 23,
        26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 40, 41, 42, 43, 44,
        45, 46, 47, 48, 49, 50, 52, 54, 56, 58, 59, 60, 62, 63, 64, 65, 66,
        67, 68, 69, 70, 71, 72, 73, 77, 78, 79, 80, 83, 85, 86, 87, 88, 90,
        92, 93, 94, 95, 97, 98, 99]),
 array([ 1,  2,  3,  4, 11, 12, 13, 14, 16, 17, 19, 20, 22, 23, 24, 25, 27,
        28, 30, 32, 33, 34, 35, 37, 38, 39, 40, 43, 44, 45, 48, 50, 53, 54,
        55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
        73, 74, 75, 76, 77, 78, 80, 81, 82, 84, 85, 86, 88, 89, 90, 91, 92,
        94, 96, 97, 98, 99]),
 array([ 1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 13, 14, 15, 16, 17, 18, 19,
        21, 22, 23, 24, 26, 27, 28, 29, 31, 33, 34, 36, 39, 40, 42, 43, 44,
        45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63,
        64, 65, 67, 68, 69, 70, 72, 73, 74, 78, 79, 81, 82, 83, 87, 88, 90,
        91, 93, 94, 97, 99]),
 array([ 0,  1,  3,  4,  5,  7,  8,  9, 10, 11, 12, 14, 15, 16, 18, 19, 20,
        21, 22, 24, 25, 26, 27, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40,
        41, 42, 43, 44, 45, 46, 47, 49, 50, 51, 54, 55, 57, 58, 59, 60, 61,
        62, 63, 64, 65, 66, 68, 69, 70, 72, 74, 75, 76, 77, 78, 79, 80, 82,
        84, 86, 90, 91, 93, 94, 95, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  9, 10, 12, 13, 15, 17, 18, 19, 22,
        24, 26, 27, 28, 29, 32, 34, 36, 37, 38, 40, 41, 43, 45, 46, 47, 48,
        49, 50, 51, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
        70, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 86, 87, 89, 90,
        92, 93, 94, 95, 97, 99]),
 array([ 1,  2,  4,  5,  8,  9, 10, 12, 13, 14, 16, 17, 19, 20, 21, 24, 25,
        26, 27, 29, 31, 34, 35, 37, 38, 39, 40, 41, 45, 46, 48, 50, 52, 53,
        55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 76, 77, 79, 80, 81, 83, 85, 86, 87, 88, 89, 91, 92, 93,
        94, 96, 98]),
 array([ 0,  1,  2,  3,  6,  8,  9, 10, 14, 15, 17, 18, 19, 20, 21, 22, 23,
        25, 27, 28, 29, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43, 45, 46, 47,
        48, 50, 51, 52, 53, 54, 55, 57, 58, 59, 60, 63, 65, 66, 67, 68, 69,
        71, 72, 74, 75, 76, 77, 78, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90,
        92, 94, 95, 96, 97, 98, 99]),
 array([ 2,  3,  4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 16, 19, 20, 21, 22,
        23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
        41, 42, 43, 46, 47, 48, 49, 51, 52, 53, 57, 58, 59, 62, 64, 65, 66,
        67, 68, 69, 70, 71, 72, 73, 76, 77, 78, 80, 82, 83, 84, 85, 86, 87,
        89, 90, 91, 92, 93, 94, 96, 97, 98]),
 array([ 0,  1,  2,  3,  4,  5,  7,  8,  9, 11, 12, 15, 16, 17, 18, 19, 20,
        21, 22, 24, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 43,
        44, 45, 47, 48, 49, 50, 52, 53, 54, 57, 58, 60, 61, 62, 63, 64, 68,
        71, 72, 73, 75, 76, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92,
        98, 99]),
 array([ 0,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 18, 19,
        21, 23, 24, 26, 27, 28, 30, 31, 32, 36, 39, 40, 41, 43, 45, 46, 47,
        48, 50, 53, 54, 55, 56, 59, 60, 64, 65, 66, 67, 69, 70, 71, 72, 73,
        74, 75, 77, 78, 79, 80, 81, 84, 85, 86, 87, 89, 90, 91, 94, 95, 96,
        98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  6,  8,  9, 10, 11, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 34, 35, 36, 37, 38, 39, 40,
        43, 44, 48, 49, 50, 51, 52, 54, 55, 56, 58, 59, 60, 62, 63, 64, 66,
        67, 68, 69, 70, 71, 72, 74, 75, 76, 78, 79, 80, 81, 82, 85, 86, 88,
        89, 90, 91, 92, 93, 94, 96, 97, 98, 99]),
 array([ 0,  2,  3,  4,  5,  7,  8,  9, 11, 12, 13, 14, 16, 17, 18, 19, 20,
        21, 22, 24, 28, 29, 30, 31, 32, 34, 35, 36, 38, 39, 40, 42, 43, 44,
        45, 46, 48, 49, 50, 51, 52, 53, 54, 56, 57, 59, 60, 62, 63, 64, 65,
        66, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 83, 84,
        86, 87, 88, 90, 91, 92, 93, 94, 96, 98, 99]),
 array([ 0,  1,  2,  4,  5,  8,  9, 10, 12, 13, 14, 15, 17, 19, 20, 21, 23,
        26, 27, 28, 29, 30, 32, 34, 40, 42, 43, 44, 45, 46, 47, 49, 50, 51,
        52, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 66, 67, 68, 69, 70, 71,
        72, 74, 75, 76, 77, 79, 80, 82, 83, 84, 86, 87, 88, 89, 90, 91, 92,
        94, 97, 98]),
 array([ 1,  2,  3,  4,  5,  6,  7,  9, 10, 11, 12, 14, 16, 17, 18, 19, 21,
        22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43,
        44, 45, 47, 48, 49, 50, 51, 52, 54, 55, 56, 59, 60, 61, 63, 64, 65,
        66, 68, 69, 70, 71, 72, 75, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87,
        88, 89, 90, 91, 95, 96, 97, 99]),
 array([ 0,  2,  3,  4,  5,  7, 10, 11, 12, 14, 15, 18, 19, 20, 21, 24, 25,
        26, 27, 28, 30, 31, 34, 36, 37, 40, 41, 44, 46, 49, 51, 53, 54, 55,
        56, 57, 58, 61, 62, 63, 64, 65, 66, 68, 71, 73, 75, 76, 78, 79, 80,
        82, 84, 85, 86, 90, 92, 93, 94, 96, 97, 98, 99]),
 array([ 0,  4,  5,  6,  9, 10, 11, 13, 14, 15, 16, 17, 19, 20, 21, 22, 24,
        25, 27, 28, 29, 30, 31, 32, 35, 37, 39, 40, 41, 42, 43, 45, 46, 47,
        48, 49, 51, 53, 56, 57, 60, 61, 62, 65, 66, 68, 69, 71, 72, 73, 74,
        75, 76, 77, 78, 82, 83, 85, 87, 89, 92, 94, 95, 96, 98]),
 array([ 0,  1,  4,  5,  6,  7,  9, 10, 11, 12, 15, 16, 18, 19, 20, 22, 23,
        24, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 40, 42, 43, 44, 45,
        47, 48, 49, 50, 51, 54, 56, 57, 58, 59, 60, 62, 64, 65, 66, 67, 69,
        70, 72, 73, 74, 75, 77, 78, 79, 80, 81, 83, 85, 86, 88, 89, 94, 96,
        97, 98]),
 array([ 1,  2,  3,  5,  6,  7,  8,  9, 12, 14, 15, 16, 17, 20, 21, 22, 23,
        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 39, 40, 42, 44, 46,
        47, 48, 49, 50, 54, 55, 56, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69,
        71, 72, 73, 74, 75, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89,
        91, 92, 95, 96, 99]),
 array([ 0,  3,  4,  7,  8,  9, 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 24,
        27, 28, 29, 30, 31, 32, 34, 36, 38, 40, 42, 44, 45, 46, 47, 49, 50,
        51, 52, 53, 54, 55, 56, 57, 59, 60, 65, 66, 67, 68, 69, 70, 71, 72,
        73, 75, 76, 78, 79, 80, 81, 84, 85, 86, 87, 88, 89, 90, 91, 93, 95,
        96, 97, 98, 99]),
 array([ 0,  1,  2,  3,  4,  6,  7,  8,  9, 11, 15, 16, 17, 18, 19, 21, 22,
        23, 24, 26, 27, 29, 30, 32, 33, 34, 35, 37, 40, 41, 42, 43, 44, 45,
        46, 47, 48, 49, 50, 51, 52, 55, 57, 58, 60, 61, 62, 63, 64, 65, 66,
        67, 72, 75, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 97,
        99]),
 array([ 1,  2,  3,  4,  5,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 32, 33, 34, 35, 36, 37, 40,
        43, 44, 46, 47, 48, 49, 50, 51, 53, 54, 55, 57, 61, 62, 63, 64, 65,
        66, 67, 69, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 85, 86, 87, 88,
        89, 91, 92, 93, 94, 96, 97]),
 array([ 0,  1,  2,  3,  5,  6,  9, 10, 11, 12, 14, 16, 19, 20, 21, 22, 23,
        24, 25, 26, 28, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 44, 46, 47,
        48, 49, 52, 53, 54, 57, 58, 59, 61, 62, 65, 66, 67, 68, 69, 70, 72,
        73, 75, 76, 77, 78, 79, 80, 83, 85, 86, 87, 88, 89, 91, 96, 98, 99]),
 array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 11, 12, 16, 17, 18, 19, 20, 23,
        24, 25, 26, 27, 28, 30, 31, 33, 35, 36, 37, 38, 39, 41, 42, 43, 44,
        45, 46, 47, 49, 52, 53, 54, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65,
        66, 67, 69, 72, 74, 75, 76, 77, 79, 81, 82, 83, 84, 85, 87, 88, 90,
        91, 92, 95, 96, 97, 98, 99]),
 array([ 0,  1,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 25, 26, 33, 34, 35, 37, 38, 41, 42, 43, 44,
        45, 46, 47, 48, 49, 50, 51, 54, 56, 57, 58, 59, 61, 63, 65, 67, 68,
        69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 88,
        90, 93, 94, 95, 96, 97, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8, 12, 13, 14, 16, 17, 20, 22, 23,
        24, 26, 28, 29, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 44, 45, 46,
        48, 51, 53, 54, 55, 56, 57, 59, 60, 61, 62, 64, 65, 66, 67, 68, 70,
        73, 74, 75, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 91, 92,
        94, 95, 97, 98]),
 array([ 0,  1,  4,  6,  8,  9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 21, 22,
        23, 25, 26, 29, 30, 31, 32, 33, 34, 35, 38, 40, 41, 45, 46, 47, 48,
        49, 53, 54, 55, 57, 58, 59, 64, 65, 67, 72, 74, 75, 76, 78, 81, 82,
        84, 85, 86, 89, 90, 92, 93, 96, 97, 99]),
 array([ 0,  1,  4,  5,  6,  9, 15, 16, 17, 19, 20, 21, 23, 24, 25, 30, 31,
        33, 34, 35, 38, 39, 41, 42, 44, 47, 48, 50, 51, 52, 54, 57, 58, 59,
        61, 63, 64, 65, 66, 68, 70, 71, 72, 73, 74, 76, 77, 78, 79, 83, 84,
        85, 86, 87, 88, 90, 91, 92, 93, 94, 95, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7, 10, 13, 15, 16, 17, 18, 19, 20, 23,
        26, 28, 30, 31, 33, 34, 35, 36, 37, 39, 40, 41, 42, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 58, 59, 60, 61, 62, 65, 66, 67, 69, 70,
        71, 74, 75, 78, 81, 82, 83, 84, 85, 86, 87, 88, 90, 92, 94, 95, 96,
        98]),
 array([ 0,  1,  2,  4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 15, 17, 18, 19,
        21, 22, 23, 24, 25, 26, 27, 29, 30, 32, 33, 34, 35, 36, 38, 39, 40,
        41, 43, 44, 45, 47, 48, 49, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61,
        64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 79, 81, 82, 85,
        86, 87, 88, 89, 90, 93, 94, 95, 96, 98, 99]),
 array([ 0,  2,  4,  5,  7,  8,  9, 11, 13, 14, 15, 18, 19, 21, 22, 24, 26,
        27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46,
        47, 48, 50, 52, 54, 55, 56, 57, 59, 61, 62, 63, 64, 65, 66, 67, 68,
        71, 72, 73, 75, 76, 79, 80, 81, 82, 83, 84, 87, 88, 90, 93, 94, 96]),
 array([ 0,  2,  4,  5,  6,  7,  8, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20,
        23, 24, 26, 27, 28, 29, 33, 35, 36, 37, 39, 40, 41, 43, 44, 45, 46,
        48, 49, 50, 51, 52, 53, 54, 55, 57, 59, 60, 61, 62, 63, 64, 66, 69,
        71, 73, 74, 75, 76, 79, 81, 82, 87, 90, 95, 96, 97, 99]),
 array([ 0,  2,  3,  4,  5,  6,  8,  9, 10, 12, 14, 15, 16, 19, 21, 23, 24,
        25, 26, 28, 29, 30, 31, 32, 33, 35, 37, 38, 39, 41, 42, 43, 44, 45,
        46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 58, 64, 66, 69, 74, 81,
        82, 83, 84, 85, 86, 87, 88, 90, 91, 92, 94, 97, 98]),
 array([ 0,  3,  4,  6,  7,  8,  9, 11, 14, 15, 16, 19, 20, 22, 23, 24, 25,
        27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 39, 40, 41, 43, 44, 45, 46,
        47, 49, 51, 52, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 68,
        70, 71, 72, 73, 74, 76, 78, 79, 80, 83, 85, 86, 87, 88, 89, 90, 91,
        93, 94, 96, 98, 99]),
 array([ 1,  2,  3,  4,  5,  7,  8,  9, 10, 11, 13, 14, 16, 18, 19, 20, 24,
        27, 28, 29, 30, 34, 36, 37, 38, 39, 40, 41, 42, 43, 46, 47, 48, 51,
        55, 56, 57, 58, 59, 62, 63, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
        75, 82, 83, 85, 86, 87, 88, 90, 92, 93, 95, 96, 97]),
 array([ 0,  3,  4,  5,  6,  8, 10, 12, 15, 16, 17, 19, 20, 23, 24, 26, 29,
        31, 32, 33, 34, 36, 37, 38, 42, 43, 44, 45, 46, 47, 48, 49, 51, 52,
        54, 57, 58, 60, 62, 63, 64, 65, 66, 67, 68, 71, 72, 74, 75, 77, 78,
        79, 81, 82, 83, 85, 86, 87, 89, 90, 91, 92, 94, 95, 96, 97, 99]),
 array([ 0,  1,  2,  3,  4,  5,  7, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22,
        23, 24, 25, 27, 28, 31, 33, 34, 36, 37, 38, 41, 44, 47, 48, 51, 54,
        55, 56, 57, 60, 62, 63, 64, 66, 67, 69, 70, 72, 73, 74, 75, 76, 77,
        78, 80, 81, 83, 84, 86, 87, 88, 90, 92, 93, 96, 98, 99]),
 array([ 1,  2,  3,  4,  5,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 18, 19,
        20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 35, 37, 40, 42, 43,
        46, 47, 48, 49, 50, 51, 52, 53, 56, 57, 58, 59, 61, 62, 65, 66, 68,
        69, 70, 71, 73, 75, 76, 77, 80, 82, 84, 87, 88, 90, 92, 93, 94, 95,
        96, 97, 98, 99]),
 array([ 0,  3,  4,  5,  8,  9, 11, 12, 13, 14, 16, 17, 18, 19, 21, 22, 23,
        25, 27, 29, 30, 31, 33, 34, 36, 37, 38, 39, 40, 42, 44, 46, 48, 49,
        51, 52, 53, 54, 57, 58, 59, 60, 62, 65, 66, 67, 68, 71, 72, 73, 74,
        75, 76, 77, 78, 79, 80, 83, 85, 86, 87, 88, 89, 90, 92, 93, 94, 95,
        98, 99]),
 array([ 0,  1,  3,  4,  6,  8,  9, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        23, 24, 25, 26, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
        42, 43, 44, 45, 46, 47, 49, 50, 52, 55, 56, 58, 60, 61, 62, 63, 64,
        65, 66, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 82, 83, 84,
        85, 86, 87, 89, 90, 91, 92, 93, 94, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  7,  8,  9, 10, 11, 13, 15, 16, 17, 18, 19,
        20, 21, 23, 25, 26, 27, 29, 32, 33, 34, 35, 38, 39, 42, 43, 45, 46,
        49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 64, 66, 68, 70, 71,
        72, 74, 76, 77, 78, 79, 80, 81, 83, 85, 87, 88, 94, 98]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 18,
        20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 33, 34, 36, 37, 39, 40, 41,
        44, 46, 47, 48, 50, 52, 55, 56, 57, 59, 60, 61, 62, 65, 67, 68, 69,
        70, 71, 72, 73, 75, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90,
        91, 93, 94, 95, 96, 98, 99]),
 array([ 1,  2,  3,  5,  6,  7,  9, 10, 11, 12, 14, 16, 17, 21, 22, 23, 25,
        26, 28, 29, 31, 33, 35, 36, 37, 38, 40, 41, 43, 44, 45, 46, 49, 50,
        51, 52, 53, 54, 56, 59, 60, 61, 62, 65, 66, 69, 70, 72, 73, 75, 76,
        78, 79, 81, 82, 83, 84, 85, 86, 88, 90, 91, 92, 96, 97, 98, 99]),
 array([ 0,  1,  2,  4,  5,  6,  8,  9, 11, 12, 13, 15, 16, 17, 19, 21, 22,
        23, 25, 26, 27, 28, 29, 31, 32, 33, 35, 36, 37, 38, 39, 41, 42, 43,
        44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62,
        63, 65, 66, 67, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83,
        86, 88, 91, 92, 93, 94, 95, 96, 98]),
 array([ 0,  1,  2,  3,  4,  5,  6,  9, 10, 11, 13, 14, 16, 18, 19, 20, 21,
        22, 24, 26, 27, 30, 31, 32, 33, 34, 35, 36, 38, 40, 41, 42, 43, 44,
        45, 47, 49, 50, 52, 53, 54, 57, 60, 64, 65, 67, 68, 69, 71, 73, 74,
        75, 76, 77, 78, 79, 81, 83, 84, 87, 88, 90, 91, 93, 96, 98, 99]),
 array([ 0,  1,  2,  3,  5,  6,  7,  8,  9, 11, 15, 16, 18, 19, 22, 23, 24,
        25, 26, 29, 30, 31, 32, 36, 37, 38, 39, 40, 42, 43, 44, 47, 48, 49,
        50, 51, 52, 54, 55, 57, 58, 59, 61, 62, 64, 66, 69, 70, 71, 72, 73,
        74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 90, 91, 93,
        94, 96, 97, 98, 99]),
 array([ 0,  1,  2,  4,  5,  6,  7,  9, 11, 12, 13, 14, 16, 17, 19, 22, 23,
        24, 27, 28, 30, 32, 33, 36, 37, 38, 39, 41, 45, 47, 48, 49, 50, 51,
        55, 57, 58, 60, 63, 64, 65, 67, 69, 70, 71, 72, 74, 75, 76, 81, 82,
        83, 84, 85, 86, 87, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98]),
 array([ 1,  2,  3,  5,  6,  7,  8, 11, 12, 13, 15, 16, 19, 21, 22, 24, 25,
        26, 27, 28, 30, 31, 33, 34, 35, 36, 37, 38, 40, 41, 42, 44, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 62, 63, 65, 66, 67,
        69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 83, 85, 87, 88, 90,
        91, 94, 95, 99]),
 array([ 0,  1,  2,  3,  4,  5,  7,  8, 11, 13, 15, 19, 20, 21, 22, 23, 26,
        27, 28, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 45, 46, 48,
        49, 50, 52, 53, 54, 55, 56, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70,
        71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 84, 86, 87, 88, 89, 90, 91,
        93, 94, 95, 96, 97, 98, 99]),
 array([ 1,  2,  3,  4,  5,  7, 10, 11, 16, 17, 18, 20, 21, 22, 23, 24, 25,
        26, 28, 29, 30, 32, 33, 34, 36, 37, 39, 40, 41, 42, 43, 44, 47, 49,
        52, 53, 54, 55, 56, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 83, 85, 86, 88, 89, 90, 91,
        92, 93, 94, 97, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 14, 15, 16, 17,
        18, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 39, 41,
        42, 43, 44, 45, 46, 48, 49, 51, 52, 53, 54, 55, 57, 58, 63, 64, 66,
        70, 73, 75, 76, 78, 79, 81, 83, 84, 88, 89, 90, 91, 92, 93, 95, 96,
        97, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  7,  8,  9, 10, 12, 14, 15, 16, 17, 18, 19,
        20, 22, 23, 24, 26, 27, 28, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43,
        44, 46, 48, 49, 50, 51, 53, 55, 56, 59, 60, 61, 65, 66, 67, 68, 69,
        70, 71, 73, 74, 75, 76, 77, 78, 79, 80, 83, 84, 85, 87, 88, 90, 91,
        93, 96, 97, 98, 99]),
 array([ 3,  6,  8,  9, 10, 11, 12, 13, 15, 17, 18, 19, 20, 21, 22, 24, 27,
        29, 30, 32, 33, 34, 35, 36, 39, 40, 41, 42, 43, 46, 47, 50, 51, 54,
        55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 71, 72, 73, 75, 76, 77,
        78, 80, 81, 84, 85, 86, 87, 88, 89, 91, 93, 94, 95, 97, 98, 99]),
 array([ 2,  3,  5,  7,  8,  9, 10, 11, 12, 14, 16, 17, 20, 21, 22, 24, 25,
        26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 38, 40, 41, 42, 46, 47, 49,
        50, 51, 53, 54, 57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 68, 69, 70,
        71, 73, 74, 75, 76, 77, 78, 79, 80, 81, 85, 86, 89, 90, 91, 92, 93,
        94, 95, 96, 97, 98]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7, 11, 12, 13, 14, 15, 17, 18, 20, 21,
        23, 24, 27, 28, 32, 34, 35, 36, 39, 40, 41, 43, 44, 46, 47, 48, 49,
        50, 51, 54, 56, 57, 58, 59, 61, 64, 67, 68, 69, 71, 72, 74, 75, 77,
        78, 80, 81, 82, 83, 84, 87, 88, 91, 94, 96, 97]),
 array([ 1,  3,  4,  5,  7,  8,  9, 10, 11, 12, 13, 17, 18, 22, 23, 26, 27,
        29, 31, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
        50, 54, 56, 59, 61, 62, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 77,
        78, 79, 80, 81, 82, 83, 84, 86, 88, 89, 92, 96, 98]),
 array([ 0,  1,  2,  3,  8,  9, 11, 13, 14, 15, 17, 18, 19, 21, 22, 24, 25,
        26, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46,
        47, 49, 51, 52, 53, 54, 56, 57, 59, 60, 61, 62, 63, 64, 66, 67, 69,
        72, 74, 75, 76, 77, 78, 82, 83, 84, 87, 88, 89, 90, 92, 93, 96, 97,
        98, 99]),
 array([ 0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 15, 16, 17, 18,
        19, 24, 25, 27, 28, 29, 30, 31, 32, 35, 36, 37, 39, 41, 43, 44, 45,
        47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 64, 65,
        66, 67, 69, 70, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 87, 89,
        91, 93, 94, 96, 98]),
 array([ 0,  1,  5,  7,  8,  9, 10, 13, 14, 15, 16, 20, 21, 22, 23, 24, 25,
        27, 28, 29, 31, 32, 34, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47,
        48, 49, 51, 53, 54, 55, 56, 57, 59, 60, 62, 63, 65, 66, 67, 69, 71,
        72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 89, 91, 92, 95, 96,
        97, 99]),
 array([ 0,  2,  3,  4,  6,  7,  9, 10, 11, 12, 13, 15, 16, 17, 18, 20, 22,
        23, 24, 25, 26, 27, 28, 29, 32, 33, 35, 37, 39, 40, 41, 46, 47, 48,
        50, 51, 52, 53, 56, 58, 59, 60, 63, 64, 65, 66, 67, 68, 70, 71, 73,
        74, 75, 76, 77, 78, 79, 81, 82, 83, 86, 87, 88, 90, 91, 92, 93, 95,
        98]),
 array([ 2,  3,  4,  5,  6,  8, 11, 14, 15, 16, 18, 19, 20, 22, 24, 25, 26,
        27, 29, 31, 32, 36, 38, 39, 43, 45, 46, 47, 48, 50, 51, 53, 55, 56,
        57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69, 70, 72, 73, 74, 76,
        77, 78, 79, 80, 82, 83, 84, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
        96, 97, 98]),
 array([ 1,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 14, 15, 16, 17, 18, 19,
        20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 35, 36, 37, 39, 40,
        41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 56, 57, 60, 61,
        64, 67, 68, 69, 70, 71, 72, 73, 74, 75, 77, 78, 80, 81, 82, 83, 84,
        85, 86, 87, 88, 89, 90, 91, 96, 98, 99]),
 array([ 0,  1,  2,  5,  7,  9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21,
        23, 25, 26, 27, 28, 29, 33, 36, 37, 38, 39, 41, 42, 43, 45, 46, 48,
        50, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 63, 67, 69, 70, 71, 72,
        73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 93, 94,
        95, 96, 98]),
 array([ 0,  1,  2,  3,  5,  6,  7,  9, 10, 13, 14, 16, 17, 18, 19, 20, 21,
        24, 27, 29, 32, 33, 34, 35, 37, 38, 40, 41, 42, 44, 45, 46, 47, 50,
        51, 53, 56, 59, 60, 61, 62, 63, 64, 66, 69, 71, 72, 73, 74, 75, 77,
        79, 83, 84, 85, 86, 87, 92, 93, 94, 95, 97, 98, 99]),
 array([ 0,  1,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 15, 16, 19, 20, 21,
        22, 25, 27, 28, 29, 30, 31, 32, 34, 35, 36, 38, 41, 42, 43, 44, 45,
        46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 64, 67, 69, 70,
        72, 73, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 88, 91, 92, 93, 95,
        98, 99]),
 array([ 0,  7,  8,  9, 10, 12, 13, 15, 16, 17, 18, 19, 24, 26, 28, 31, 32,
        33, 39, 42, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 60,
        61, 62, 63, 64, 68, 69, 70, 71, 73, 74, 75, 76, 77, 78, 79, 82, 83,
        84, 85, 86, 88, 89, 92, 94, 95, 96, 98, 99])]

In [5]:
#class select_random_path(object):    
#    def __call__(self):
#        while True:
#            self.cnt += 1
#            yield XX[np.random.randint(low=0, high=len(XX))]

def select_random_path():    
    while True:
        yield XX[np.random.randint(low=0, high=len(XX))]

In [6]:
class RITNode(object):
    def __init__(self, val):
        self._val = val
        self._children = []

    def is_leaf(self):
        return len(self._children) == 0

    @property
    def children(self):
        return self._children
    
    def add_child(self, val):
        val_intersect = np.intersect1d(self._val, val)
        self._children.append(RITNode(val_intersect))
        
    def is_empty(self):
        return len(self._val) == 0
        
    def is_leaf(self):
        return len(self._children) == 0
    
    @property
    def nr_children(self):
        return len(self._children) + sum(child.nr_children for child in self._children)
    
    def _traverse_depth_first(self, _idx):
        yield _idx[0], self
        for child in self.children:
            _idx[0] += 1
            yield from RITNode._traverse_depth_first(child, _idx=_idx)    

class RITTree(RITNode):
    def __len__(self):
        return self.nr_children + 1
    
    def traverse_depth_first(self):
        yield from RITNode._traverse_depth_first(self, _idx=[0])

In [7]:
from functools import partial

def build_tree(feature_paths, max_depth=3, num_splits=5, noisy_split=False, _parent=None, _depth=0):
    """
    Parameters
    ----------
    feature_paths : generator of list of ints
        ...
    max_depth : int
        The built tree will never be deeper than `max_depth`.
    num_splits : int
        At each node, the maximum number of children to be added.
    """
    expand_tree = partial(build_tree, feature_paths, max_depth=max_depth,
                          num_splits=num_splits, noisy_split=noisy_split)
    
    if _parent is None:
        tree = RITTree(next(feature_paths))
        expand_tree(_parent=tree, _depth=0)
        return tree
    else:
        _depth += 1
        if _depth >= max_depth:
            return
        if noisy_split:
            num_splits += np.random.randint(low=0, high=2)
        for i in range(num_splits):
            _parent.add_child(next(feature_paths))
            added_node = _parent.children[-1]
            if not added_node.is_empty():
                expand_tree(_parent=added_node, _depth=_depth)

In [8]:
np.random.seed(12)
tree = build_tree(feature_paths=select_random_path(), max_depth=3, noisy_split=False, num_splits=5)
#path_gen = select_random_path()
#tree = build_tree(feature_paths=path_gen(), max_depth=3, num_splits=5)

In [9]:
#%timeit build_tree(feature_paths=select_random_path())

In [10]:
print("Root:\n", tree._val)
print("Some child:\n", tree.children[0].children[1]._val)


Root:
 [ 1  3  4  5  6  7  8  9 10 11 12 14 15 16 17 18 19 20 21 22 23 24 25 26 27
 29 30 31 32 35 36 37 39 40 41 42 43 44 46 47 48 49 50 51 52 53 54 56 57 60
 61 64 67 68 69 70 71 72 73 74 75 77 78 80 81 82 83 84 85 86 87 88 89 90 91
 96 98 99]
Some child:
 [ 5  8  9 10 12 14 15 17 19 20 21 23 27 29 30 43 44 46 47 49 50 51 52 56 57
 60 61 67 68 69 72 75 80 82 83 84 86 89 90 91]

In [11]:
# If noisy split is False, this should pass
assert(len(tree) == 1 + 5 + 5**2)

In [15]:
# If noisy split is True, this should pass
print(len(tree))
assert(len(tree) <= 1 + 6 + 6**2)


31

In [16]:
list(tree.traverse_depth_first())


Out[16]:
[(0, <__main__.RITTree at 0x10b3c2898>),
 (1, <__main__.RITNode at 0x10b3c2e80>),
 (2, <__main__.RITNode at 0x10b3f58d0>),
 (3, <__main__.RITNode at 0x10b3c2d30>),
 (4, <__main__.RITNode at 0x10b3f59b0>),
 (5, <__main__.RITNode at 0x10b3c2588>),
 (6, <__main__.RITNode at 0x10b3c2908>),
 (7, <__main__.RITNode at 0x10b3c2550>),
 (8, <__main__.RITNode at 0x10b3c2518>),
 (9, <__main__.RITNode at 0x10b3c2a20>),
 (10, <__main__.RITNode at 0x10b3c2470>),
 (11, <__main__.RITNode at 0x10b3c2ef0>),
 (12, <__main__.RITNode at 0x10b3c2390>),
 (13, <__main__.RITNode at 0x10b3c28d0>),
 (14, <__main__.RITNode at 0x10b3c2b00>),
 (15, <__main__.RITNode at 0x10b3c2cc0>),
 (16, <__main__.RITNode at 0x10b3c25f8>),
 (17, <__main__.RITNode at 0x10b3c2b70>),
 (18, <__main__.RITNode at 0x10b3c2b38>),
 (19, <__main__.RITNode at 0x10b3c2ac8>),
 (20, <__main__.RITNode at 0x10b3c2e10>),
 (21, <__main__.RITNode at 0x10b3c2dd8>),
 (22, <__main__.RITNode at 0x10b404048>),
 (23, <__main__.RITNode at 0x10b404080>),
 (24, <__main__.RITNode at 0x10b4040b8>),
 (25, <__main__.RITNode at 0x10b4040f0>),
 (26, <__main__.RITNode at 0x10b404128>),
 (27, <__main__.RITNode at 0x10b404160>),
 (28, <__main__.RITNode at 0x10b404198>),
 (29, <__main__.RITNode at 0x10b4041d0>),
 (30, <__main__.RITNode at 0x10b404208>)]