In [2]:
import numpy as np
In [3]:
X = np.random.random(size=(80, 100)) > 0.3
In [4]:
XX = [np.nonzero(row)[0] for row in X]
XX
Out[4]:
[array([ 1, 2, 4, 5, 7, 8, 9, 10, 12, 13, 16, 17, 18, 19, 20, 21, 22,
26, 27, 28, 31, 32, 33, 36, 37, 38, 39, 40, 41, 42, 43, 47, 49, 50,
51, 52, 53, 56, 57, 58, 59, 60, 61, 63, 65, 66, 68, 69, 70, 71, 72,
73, 74, 76, 77, 78, 81, 82, 84, 85, 86, 88, 89, 90, 91, 92, 93, 94,
95, 96, 97, 98, 99]),
array([ 0, 1, 2, 4, 5, 6, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20,
23, 24, 25, 26, 27, 28, 29, 30, 33, 34, 35, 36, 38, 39, 40, 41, 43,
45, 46, 47, 48, 49, 52, 53, 55, 56, 57, 58, 59, 60, 61, 63, 65, 66,
68, 69, 70, 71, 73, 74, 76, 79, 80, 82, 83, 84, 85, 87, 89, 90, 92,
94, 96, 97, 98, 99]),
array([ 0, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 19, 20, 21,
23, 24, 25, 27, 29, 30, 31, 33, 35, 36, 37, 38, 39, 41, 43, 44, 45,
46, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
65, 67, 68, 69, 72, 73, 75, 76, 78, 79, 80, 81, 82, 83, 84, 85, 86,
89, 90, 91, 93, 94, 95, 96, 97, 99]),
array([ 0, 1, 2, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 18, 19, 20, 23,
24, 25, 26, 27, 28, 30, 31, 32, 34, 36, 37, 38, 39, 40, 41, 44, 45,
46, 51, 52, 53, 55, 56, 57, 59, 61, 62, 63, 64, 65, 66, 67, 71, 72,
73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
91, 93, 94, 95, 96, 97, 98, 99]),
array([ 1, 3, 5, 6, 8, 9, 10, 11, 12, 16, 18, 19, 23, 26, 28, 29, 30,
31, 35, 36, 37, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 52, 53, 54,
56, 57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 69, 70, 71, 72, 73, 75,
76, 77, 78, 79, 80, 81, 82, 83, 85, 87, 88, 89, 92, 93, 94, 96, 97,
98]),
array([ 0, 1, 3, 4, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20,
22, 23, 24, 26, 28, 29, 30, 32, 33, 34, 35, 37, 38, 40, 43, 45, 46,
48, 49, 50, 51, 53, 54, 55, 57, 59, 60, 63, 64, 65, 68, 69, 71, 73,
76, 77, 78, 79, 80, 82, 83, 84, 86, 88, 89, 91, 92, 93, 94, 95, 96,
99]),
array([ 0, 1, 2, 3, 7, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 23,
27, 29, 30, 32, 33, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 53, 55, 56, 57, 58, 59, 61, 62, 64, 65, 66, 67, 68, 69,
70, 72, 73, 75, 76, 77, 78, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
91, 92, 93, 95, 96, 97, 98, 99]),
array([ 1, 2, 5, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23,
24, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
44, 47, 48, 50, 51, 52, 53, 54, 56, 57, 61, 62, 63, 64, 65, 66, 67,
68, 69, 72, 73, 74, 76, 78, 79, 80, 81, 82, 84, 85, 86, 88, 91, 92,
93, 94, 95, 97, 98]),
array([ 0, 1, 4, 5, 8, 9, 10, 12, 14, 15, 19, 20, 21, 22, 23, 25, 26,
27, 28, 29, 31, 33, 34, 35, 36, 39, 40, 41, 42, 43, 44, 45, 46, 51,
53, 54, 59, 60, 63, 64, 65, 66, 68, 69, 70, 71, 72, 73, 74, 75, 77,
78, 79, 80, 81, 82, 83, 84, 86, 89, 90, 92, 93, 94, 96, 97, 98]),
array([ 0, 2, 3, 4, 5, 8, 9, 11, 13, 14, 15, 16, 18, 19, 20, 22, 23,
25, 27, 28, 29, 30, 31, 33, 34, 35, 38, 40, 41, 43, 44, 45, 46, 47,
48, 53, 54, 55, 59, 60, 61, 62, 64, 65, 66, 67, 68, 70, 72, 73, 74,
75, 76, 78, 79, 80, 81, 82, 83, 85, 86, 87, 88, 91, 93, 95, 97, 99]),
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 47, 49, 51, 52, 53, 55, 56, 57,
58, 60, 61, 62, 63, 64, 65, 67, 68, 69, 72, 73, 74, 75, 79, 80, 81,
82, 83, 84, 89, 90, 92, 94, 96]),
array([ 0, 2, 3, 5, 6, 8, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 40, 41, 42,
43, 44, 45, 46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 62,
63, 64, 65, 66, 67, 68, 69, 70, 72, 73, 75, 76, 78, 82, 84, 86, 92,
93, 94, 95, 96, 98, 99]),
array([ 0, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 15, 16, 18, 19, 21,
22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 39, 41,
42, 43, 45, 46, 47, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 63, 64,
65, 66, 67, 68, 71, 72, 74, 75, 78, 85, 86, 87, 88, 91, 92, 93, 94,
95, 96, 97, 99]),
array([ 0, 2, 3, 4, 5, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 21,
23, 26, 27, 28, 30, 31, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 45,
47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 60, 61, 62, 63, 65,
66, 67, 68, 69, 70, 71, 74, 75, 76, 78, 80, 81, 82, 85, 88, 89, 92,
93, 94, 95, 96, 98, 99]),
array([ 1, 3, 4, 5, 6, 7, 8, 10, 11, 13, 15, 16, 18, 19, 20, 21, 22,
23, 26, 27, 29, 31, 32, 34, 35, 37, 38, 44, 45, 47, 48, 49, 50, 51,
53, 54, 55, 57, 58, 59, 60, 61, 64, 65, 66, 67, 68, 70, 71, 73, 74,
75, 77, 79, 80, 83, 84, 86, 87, 88, 92, 95, 97, 98, 99]),
array([ 0, 1, 2, 3, 4, 5, 7, 9, 11, 13, 14, 15, 17, 18, 20, 22, 23,
26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 52, 54, 56, 58, 59, 60, 62, 63, 64, 65, 66,
67, 68, 69, 70, 71, 72, 73, 77, 78, 79, 80, 83, 85, 86, 87, 88, 90,
92, 93, 94, 95, 97, 98, 99]),
array([ 1, 2, 3, 4, 11, 12, 13, 14, 16, 17, 19, 20, 22, 23, 24, 25, 27,
28, 30, 32, 33, 34, 35, 37, 38, 39, 40, 43, 44, 45, 48, 50, 53, 54,
55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
73, 74, 75, 76, 77, 78, 80, 81, 82, 84, 85, 86, 88, 89, 90, 91, 92,
94, 96, 97, 98, 99]),
array([ 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19,
21, 22, 23, 24, 26, 27, 28, 29, 31, 33, 34, 36, 39, 40, 42, 43, 44,
45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63,
64, 65, 67, 68, 69, 70, 72, 73, 74, 78, 79, 81, 82, 83, 87, 88, 90,
91, 93, 94, 97, 99]),
array([ 0, 1, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20,
21, 22, 24, 25, 26, 27, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, 44, 45, 46, 47, 49, 50, 51, 54, 55, 57, 58, 59, 60, 61,
62, 63, 64, 65, 66, 68, 69, 70, 72, 74, 75, 76, 77, 78, 79, 80, 82,
84, 86, 90, 91, 93, 94, 95, 98, 99]),
array([ 0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 13, 15, 17, 18, 19, 22,
24, 26, 27, 28, 29, 32, 34, 36, 37, 38, 40, 41, 43, 45, 46, 47, 48,
49, 50, 51, 54, 55, 56, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
70, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 86, 87, 89, 90,
92, 93, 94, 95, 97, 99]),
array([ 1, 2, 4, 5, 8, 9, 10, 12, 13, 14, 16, 17, 19, 20, 21, 24, 25,
26, 27, 29, 31, 34, 35, 37, 38, 39, 40, 41, 45, 46, 48, 50, 52, 53,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
72, 73, 74, 76, 77, 79, 80, 81, 83, 85, 86, 87, 88, 89, 91, 92, 93,
94, 96, 98]),
array([ 0, 1, 2, 3, 6, 8, 9, 10, 14, 15, 17, 18, 19, 20, 21, 22, 23,
25, 27, 28, 29, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43, 45, 46, 47,
48, 50, 51, 52, 53, 54, 55, 57, 58, 59, 60, 63, 65, 66, 67, 68, 69,
71, 72, 74, 75, 76, 77, 78, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90,
92, 94, 95, 96, 97, 98, 99]),
array([ 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 16, 19, 20, 21, 22,
23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, 46, 47, 48, 49, 51, 52, 53, 57, 58, 59, 62, 64, 65, 66,
67, 68, 69, 70, 71, 72, 73, 76, 77, 78, 80, 82, 83, 84, 85, 86, 87,
89, 90, 91, 92, 93, 94, 96, 97, 98]),
array([ 0, 1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 15, 16, 17, 18, 19, 20,
21, 22, 24, 26, 28, 29, 30, 31, 32, 33, 34, 36, 38, 39, 40, 41, 43,
44, 45, 47, 48, 49, 50, 52, 53, 54, 57, 58, 60, 61, 62, 63, 64, 68,
71, 72, 73, 75, 76, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92,
98, 99]),
array([ 0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19,
21, 23, 24, 26, 27, 28, 30, 31, 32, 36, 39, 40, 41, 43, 45, 46, 47,
48, 50, 53, 54, 55, 56, 59, 60, 64, 65, 66, 67, 69, 70, 71, 72, 73,
74, 75, 77, 78, 79, 80, 81, 84, 85, 86, 87, 89, 90, 91, 94, 95, 96,
98, 99]),
array([ 0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18,
19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 34, 35, 36, 37, 38, 39, 40,
43, 44, 48, 49, 50, 51, 52, 54, 55, 56, 58, 59, 60, 62, 63, 64, 66,
67, 68, 69, 70, 71, 72, 74, 75, 76, 78, 79, 80, 81, 82, 85, 86, 88,
89, 90, 91, 92, 93, 94, 96, 97, 98, 99]),
array([ 0, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20,
21, 22, 24, 28, 29, 30, 31, 32, 34, 35, 36, 38, 39, 40, 42, 43, 44,
45, 46, 48, 49, 50, 51, 52, 53, 54, 56, 57, 59, 60, 62, 63, 64, 65,
66, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 83, 84,
86, 87, 88, 90, 91, 92, 93, 94, 96, 98, 99]),
array([ 0, 1, 2, 4, 5, 8, 9, 10, 12, 13, 14, 15, 17, 19, 20, 21, 23,
26, 27, 28, 29, 30, 32, 34, 40, 42, 43, 44, 45, 46, 47, 49, 50, 51,
52, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 66, 67, 68, 69, 70, 71,
72, 74, 75, 76, 77, 79, 80, 82, 83, 84, 86, 87, 88, 89, 90, 91, 92,
94, 97, 98]),
array([ 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 14, 16, 17, 18, 19, 21,
22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43,
44, 45, 47, 48, 49, 50, 51, 52, 54, 55, 56, 59, 60, 61, 63, 64, 65,
66, 68, 69, 70, 71, 72, 75, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87,
88, 89, 90, 91, 95, 96, 97, 99]),
array([ 0, 2, 3, 4, 5, 7, 10, 11, 12, 14, 15, 18, 19, 20, 21, 24, 25,
26, 27, 28, 30, 31, 34, 36, 37, 40, 41, 44, 46, 49, 51, 53, 54, 55,
56, 57, 58, 61, 62, 63, 64, 65, 66, 68, 71, 73, 75, 76, 78, 79, 80,
82, 84, 85, 86, 90, 92, 93, 94, 96, 97, 98, 99]),
array([ 0, 4, 5, 6, 9, 10, 11, 13, 14, 15, 16, 17, 19, 20, 21, 22, 24,
25, 27, 28, 29, 30, 31, 32, 35, 37, 39, 40, 41, 42, 43, 45, 46, 47,
48, 49, 51, 53, 56, 57, 60, 61, 62, 65, 66, 68, 69, 71, 72, 73, 74,
75, 76, 77, 78, 82, 83, 85, 87, 89, 92, 94, 95, 96, 98]),
array([ 0, 1, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 18, 19, 20, 22, 23,
24, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 40, 42, 43, 44, 45,
47, 48, 49, 50, 51, 54, 56, 57, 58, 59, 60, 62, 64, 65, 66, 67, 69,
70, 72, 73, 74, 75, 77, 78, 79, 80, 81, 83, 85, 86, 88, 89, 94, 96,
97, 98]),
array([ 1, 2, 3, 5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 36, 39, 40, 42, 44, 46,
47, 48, 49, 50, 54, 55, 56, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69,
71, 72, 73, 74, 75, 76, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89,
91, 92, 95, 96, 99]),
array([ 0, 3, 4, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 24,
27, 28, 29, 30, 31, 32, 34, 36, 38, 40, 42, 44, 45, 46, 47, 49, 50,
51, 52, 53, 54, 55, 56, 57, 59, 60, 65, 66, 67, 68, 69, 70, 71, 72,
73, 75, 76, 78, 79, 80, 81, 84, 85, 86, 87, 88, 89, 90, 91, 93, 95,
96, 97, 98, 99]),
array([ 0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 15, 16, 17, 18, 19, 21, 22,
23, 24, 26, 27, 29, 30, 32, 33, 34, 35, 37, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 55, 57, 58, 60, 61, 62, 63, 64, 65, 66,
67, 72, 75, 77, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 90, 91, 97,
99]),
array([ 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 32, 33, 34, 35, 36, 37, 40,
43, 44, 46, 47, 48, 49, 50, 51, 53, 54, 55, 57, 61, 62, 63, 64, 65,
66, 67, 69, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 85, 86, 87, 88,
89, 91, 92, 93, 94, 96, 97]),
array([ 0, 1, 2, 3, 5, 6, 9, 10, 11, 12, 14, 16, 19, 20, 21, 22, 23,
24, 25, 26, 28, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 44, 46, 47,
48, 49, 52, 53, 54, 57, 58, 59, 61, 62, 65, 66, 67, 68, 69, 70, 72,
73, 75, 76, 77, 78, 79, 80, 83, 85, 86, 87, 88, 89, 91, 96, 98, 99]),
array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 16, 17, 18, 19, 20, 23,
24, 25, 26, 27, 28, 30, 31, 33, 35, 36, 37, 38, 39, 41, 42, 43, 44,
45, 46, 47, 49, 52, 53, 54, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65,
66, 67, 69, 72, 74, 75, 76, 77, 79, 81, 82, 83, 84, 85, 87, 88, 90,
91, 92, 95, 96, 97, 98, 99]),
array([ 0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 25, 26, 33, 34, 35, 37, 38, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 54, 56, 57, 58, 59, 61, 63, 65, 67, 68,
69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 88,
90, 93, 94, 95, 96, 97, 98, 99]),
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 16, 17, 20, 22, 23,
24, 26, 28, 29, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 44, 45, 46,
48, 51, 53, 54, 55, 56, 57, 59, 60, 61, 62, 64, 65, 66, 67, 68, 70,
73, 74, 75, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 91, 92,
94, 95, 97, 98]),
array([ 0, 1, 4, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 21, 22,
23, 25, 26, 29, 30, 31, 32, 33, 34, 35, 38, 40, 41, 45, 46, 47, 48,
49, 53, 54, 55, 57, 58, 59, 64, 65, 67, 72, 74, 75, 76, 78, 81, 82,
84, 85, 86, 89, 90, 92, 93, 96, 97, 99]),
array([ 0, 1, 4, 5, 6, 9, 15, 16, 17, 19, 20, 21, 23, 24, 25, 30, 31,
33, 34, 35, 38, 39, 41, 42, 44, 47, 48, 50, 51, 52, 54, 57, 58, 59,
61, 63, 64, 65, 66, 68, 70, 71, 72, 73, 74, 76, 77, 78, 79, 83, 84,
85, 86, 87, 88, 90, 91, 92, 93, 94, 95, 98, 99]),
array([ 0, 1, 2, 3, 4, 5, 6, 7, 10, 13, 15, 16, 17, 18, 19, 20, 23,
26, 28, 30, 31, 33, 34, 35, 36, 37, 39, 40, 41, 42, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 58, 59, 60, 61, 62, 65, 66, 67, 69, 70,
71, 74, 75, 78, 81, 82, 83, 84, 85, 86, 87, 88, 90, 92, 94, 95, 96,
98]),
array([ 0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19,
21, 22, 23, 24, 25, 26, 27, 29, 30, 32, 33, 34, 35, 36, 38, 39, 40,
41, 43, 44, 45, 47, 48, 49, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61,
64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 79, 81, 82, 85,
86, 87, 88, 89, 90, 93, 94, 95, 96, 98, 99]),
array([ 0, 2, 4, 5, 7, 8, 9, 11, 13, 14, 15, 18, 19, 21, 22, 24, 26,
27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46,
47, 48, 50, 52, 54, 55, 56, 57, 59, 61, 62, 63, 64, 65, 66, 67, 68,
71, 72, 73, 75, 76, 79, 80, 81, 82, 83, 84, 87, 88, 90, 93, 94, 96]),
array([ 0, 2, 4, 5, 6, 7, 8, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20,
23, 24, 26, 27, 28, 29, 33, 35, 36, 37, 39, 40, 41, 43, 44, 45, 46,
48, 49, 50, 51, 52, 53, 54, 55, 57, 59, 60, 61, 62, 63, 64, 66, 69,
71, 73, 74, 75, 76, 79, 81, 82, 87, 90, 95, 96, 97, 99]),
array([ 0, 2, 3, 4, 5, 6, 8, 9, 10, 12, 14, 15, 16, 19, 21, 23, 24,
25, 26, 28, 29, 30, 31, 32, 33, 35, 37, 38, 39, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 58, 64, 66, 69, 74, 81,
82, 83, 84, 85, 86, 87, 88, 90, 91, 92, 94, 97, 98]),
array([ 0, 3, 4, 6, 7, 8, 9, 11, 14, 15, 16, 19, 20, 22, 23, 24, 25,
27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 39, 40, 41, 43, 44, 45, 46,
47, 49, 51, 52, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 68,
70, 71, 72, 73, 74, 76, 78, 79, 80, 83, 85, 86, 87, 88, 89, 90, 91,
93, 94, 96, 98, 99]),
array([ 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 13, 14, 16, 18, 19, 20, 24,
27, 28, 29, 30, 34, 36, 37, 38, 39, 40, 41, 42, 43, 46, 47, 48, 51,
55, 56, 57, 58, 59, 62, 63, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
75, 82, 83, 85, 86, 87, 88, 90, 92, 93, 95, 96, 97]),
array([ 0, 3, 4, 5, 6, 8, 10, 12, 15, 16, 17, 19, 20, 23, 24, 26, 29,
31, 32, 33, 34, 36, 37, 38, 42, 43, 44, 45, 46, 47, 48, 49, 51, 52,
54, 57, 58, 60, 62, 63, 64, 65, 66, 67, 68, 71, 72, 74, 75, 77, 78,
79, 81, 82, 83, 85, 86, 87, 89, 90, 91, 92, 94, 95, 96, 97, 99]),
array([ 0, 1, 2, 3, 4, 5, 7, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22,
23, 24, 25, 27, 28, 31, 33, 34, 36, 37, 38, 41, 44, 47, 48, 51, 54,
55, 56, 57, 60, 62, 63, 64, 66, 67, 69, 70, 72, 73, 74, 75, 76, 77,
78, 80, 81, 83, 84, 86, 87, 88, 90, 92, 93, 96, 98, 99]),
array([ 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19,
20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 35, 37, 40, 42, 43,
46, 47, 48, 49, 50, 51, 52, 53, 56, 57, 58, 59, 61, 62, 65, 66, 68,
69, 70, 71, 73, 75, 76, 77, 80, 82, 84, 87, 88, 90, 92, 93, 94, 95,
96, 97, 98, 99]),
array([ 0, 3, 4, 5, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 21, 22, 23,
25, 27, 29, 30, 31, 33, 34, 36, 37, 38, 39, 40, 42, 44, 46, 48, 49,
51, 52, 53, 54, 57, 58, 59, 60, 62, 65, 66, 67, 68, 71, 72, 73, 74,
75, 76, 77, 78, 79, 80, 83, 85, 86, 87, 88, 89, 90, 92, 93, 94, 95,
98, 99]),
array([ 0, 1, 3, 4, 6, 8, 9, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
23, 24, 25, 26, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 49, 50, 52, 55, 56, 58, 60, 61, 62, 63, 64,
65, 66, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 82, 83, 84,
85, 86, 87, 89, 90, 91, 92, 93, 94, 98, 99]),
array([ 0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19,
20, 21, 23, 25, 26, 27, 29, 32, 33, 34, 35, 38, 39, 42, 43, 45, 46,
49, 50, 51, 52, 53, 54, 55, 56, 59, 60, 61, 62, 64, 66, 68, 70, 71,
72, 74, 76, 77, 78, 79, 80, 81, 83, 85, 87, 88, 94, 98]),
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 33, 34, 36, 37, 39, 40, 41,
44, 46, 47, 48, 50, 52, 55, 56, 57, 59, 60, 61, 62, 65, 67, 68, 69,
70, 71, 72, 73, 75, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90,
91, 93, 94, 95, 96, 98, 99]),
array([ 1, 2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 16, 17, 21, 22, 23, 25,
26, 28, 29, 31, 33, 35, 36, 37, 38, 40, 41, 43, 44, 45, 46, 49, 50,
51, 52, 53, 54, 56, 59, 60, 61, 62, 65, 66, 69, 70, 72, 73, 75, 76,
78, 79, 81, 82, 83, 84, 85, 86, 88, 90, 91, 92, 96, 97, 98, 99]),
array([ 0, 1, 2, 4, 5, 6, 8, 9, 11, 12, 13, 15, 16, 17, 19, 21, 22,
23, 25, 26, 27, 28, 29, 31, 32, 33, 35, 36, 37, 38, 39, 41, 42, 43,
44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62,
63, 65, 66, 67, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83,
86, 88, 91, 92, 93, 94, 95, 96, 98]),
array([ 0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 13, 14, 16, 18, 19, 20, 21,
22, 24, 26, 27, 30, 31, 32, 33, 34, 35, 36, 38, 40, 41, 42, 43, 44,
45, 47, 49, 50, 52, 53, 54, 57, 60, 64, 65, 67, 68, 69, 71, 73, 74,
75, 76, 77, 78, 79, 81, 83, 84, 87, 88, 90, 91, 93, 96, 98, 99]),
array([ 0, 1, 2, 3, 5, 6, 7, 8, 9, 11, 15, 16, 18, 19, 22, 23, 24,
25, 26, 29, 30, 31, 32, 36, 37, 38, 39, 40, 42, 43, 44, 47, 48, 49,
50, 51, 52, 54, 55, 57, 58, 59, 61, 62, 64, 66, 69, 70, 71, 72, 73,
74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 90, 91, 93,
94, 96, 97, 98, 99]),
array([ 0, 1, 2, 4, 5, 6, 7, 9, 11, 12, 13, 14, 16, 17, 19, 22, 23,
24, 27, 28, 30, 32, 33, 36, 37, 38, 39, 41, 45, 47, 48, 49, 50, 51,
55, 57, 58, 60, 63, 64, 65, 67, 69, 70, 71, 72, 74, 75, 76, 81, 82,
83, 84, 85, 86, 87, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98]),
array([ 1, 2, 3, 5, 6, 7, 8, 11, 12, 13, 15, 16, 19, 21, 22, 24, 25,
26, 27, 28, 30, 31, 33, 34, 35, 36, 37, 38, 40, 41, 42, 44, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 62, 63, 65, 66, 67,
69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 83, 85, 87, 88, 90,
91, 94, 95, 99]),
array([ 0, 1, 2, 3, 4, 5, 7, 8, 11, 13, 15, 19, 20, 21, 22, 23, 26,
27, 28, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 45, 46, 48,
49, 50, 52, 53, 54, 55, 56, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70,
71, 72, 74, 75, 76, 77, 78, 79, 80, 81, 84, 86, 87, 88, 89, 90, 91,
93, 94, 95, 96, 97, 98, 99]),
array([ 1, 2, 3, 4, 5, 7, 10, 11, 16, 17, 18, 20, 21, 22, 23, 24, 25,
26, 28, 29, 30, 32, 33, 34, 36, 37, 39, 40, 41, 42, 43, 44, 47, 49,
52, 53, 54, 55, 56, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70, 71,
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 83, 85, 86, 88, 89, 90, 91,
92, 93, 94, 97, 98, 99]),
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
18, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35, 39, 41,
42, 43, 44, 45, 46, 48, 49, 51, 52, 53, 54, 55, 57, 58, 63, 64, 66,
70, 73, 75, 76, 78, 79, 81, 83, 84, 88, 89, 90, 91, 92, 93, 95, 96,
97, 98, 99]),
array([ 0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19,
20, 22, 23, 24, 26, 27, 28, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43,
44, 46, 48, 49, 50, 51, 53, 55, 56, 59, 60, 61, 65, 66, 67, 68, 69,
70, 71, 73, 74, 75, 76, 77, 78, 79, 80, 83, 84, 85, 87, 88, 90, 91,
93, 96, 97, 98, 99]),
array([ 3, 6, 8, 9, 10, 11, 12, 13, 15, 17, 18, 19, 20, 21, 22, 24, 27,
29, 30, 32, 33, 34, 35, 36, 39, 40, 41, 42, 43, 46, 47, 50, 51, 54,
55, 56, 59, 60, 61, 62, 63, 65, 67, 68, 69, 71, 72, 73, 75, 76, 77,
78, 80, 81, 84, 85, 86, 87, 88, 89, 91, 93, 94, 95, 97, 98, 99]),
array([ 2, 3, 5, 7, 8, 9, 10, 11, 12, 14, 16, 17, 20, 21, 22, 24, 25,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 38, 40, 41, 42, 46, 47, 49,
50, 51, 53, 54, 57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 68, 69, 70,
71, 73, 74, 75, 76, 77, 78, 79, 80, 81, 85, 86, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98]),
array([ 0, 1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 17, 18, 20, 21,
23, 24, 27, 28, 32, 34, 35, 36, 39, 40, 41, 43, 44, 46, 47, 48, 49,
50, 51, 54, 56, 57, 58, 59, 61, 64, 67, 68, 69, 71, 72, 74, 75, 77,
78, 80, 81, 82, 83, 84, 87, 88, 91, 94, 96, 97]),
array([ 1, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 17, 18, 22, 23, 26, 27,
29, 31, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
50, 54, 56, 59, 61, 62, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 77,
78, 79, 80, 81, 82, 83, 84, 86, 88, 89, 92, 96, 98]),
array([ 0, 1, 2, 3, 8, 9, 11, 13, 14, 15, 17, 18, 19, 21, 22, 24, 25,
26, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46,
47, 49, 51, 52, 53, 54, 56, 57, 59, 60, 61, 62, 63, 64, 66, 67, 69,
72, 74, 75, 76, 77, 78, 82, 83, 84, 87, 88, 89, 90, 92, 93, 96, 97,
98, 99]),
array([ 0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18,
19, 24, 25, 27, 28, 29, 30, 31, 32, 35, 36, 37, 39, 41, 43, 44, 45,
47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 64, 65,
66, 67, 69, 70, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 87, 89,
91, 93, 94, 96, 98]),
array([ 0, 1, 5, 7, 8, 9, 10, 13, 14, 15, 16, 20, 21, 22, 23, 24, 25,
27, 28, 29, 31, 32, 34, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47,
48, 49, 51, 53, 54, 55, 56, 57, 59, 60, 62, 63, 65, 66, 67, 69, 71,
72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 89, 91, 92, 95, 96,
97, 99]),
array([ 0, 2, 3, 4, 6, 7, 9, 10, 11, 12, 13, 15, 16, 17, 18, 20, 22,
23, 24, 25, 26, 27, 28, 29, 32, 33, 35, 37, 39, 40, 41, 46, 47, 48,
50, 51, 52, 53, 56, 58, 59, 60, 63, 64, 65, 66, 67, 68, 70, 71, 73,
74, 75, 76, 77, 78, 79, 81, 82, 83, 86, 87, 88, 90, 91, 92, 93, 95,
98]),
array([ 2, 3, 4, 5, 6, 8, 11, 14, 15, 16, 18, 19, 20, 22, 24, 25, 26,
27, 29, 31, 32, 36, 38, 39, 43, 45, 46, 47, 48, 50, 51, 53, 55, 56,
57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69, 70, 72, 73, 74, 76,
77, 78, 79, 80, 82, 83, 84, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
96, 97, 98]),
array([ 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 35, 36, 37, 39, 40,
41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 56, 57, 60, 61,
64, 67, 68, 69, 70, 71, 72, 73, 74, 75, 77, 78, 80, 81, 82, 83, 84,
85, 86, 87, 88, 89, 90, 91, 96, 98, 99]),
array([ 0, 1, 2, 5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21,
23, 25, 26, 27, 28, 29, 33, 36, 37, 38, 39, 41, 42, 43, 45, 46, 48,
50, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 63, 67, 69, 70, 71, 72,
73, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 93, 94,
95, 96, 98]),
array([ 0, 1, 2, 3, 5, 6, 7, 9, 10, 13, 14, 16, 17, 18, 19, 20, 21,
24, 27, 29, 32, 33, 34, 35, 37, 38, 40, 41, 42, 44, 45, 46, 47, 50,
51, 53, 56, 59, 60, 61, 62, 63, 64, 66, 69, 71, 72, 73, 74, 75, 77,
79, 83, 84, 85, 86, 87, 92, 93, 94, 95, 97, 98, 99]),
array([ 0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 16, 19, 20, 21,
22, 25, 27, 28, 29, 30, 31, 32, 34, 35, 36, 38, 41, 42, 43, 44, 45,
46, 47, 48, 49, 51, 53, 54, 55, 56, 58, 59, 60, 61, 64, 67, 69, 70,
72, 73, 76, 77, 78, 79, 81, 82, 83, 84, 85, 86, 88, 91, 92, 93, 95,
98, 99]),
array([ 0, 7, 8, 9, 10, 12, 13, 15, 16, 17, 18, 19, 24, 26, 28, 31, 32,
33, 39, 42, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 60,
61, 62, 63, 64, 68, 69, 70, 71, 73, 74, 75, 76, 77, 78, 79, 82, 83,
84, 85, 86, 88, 89, 92, 94, 95, 96, 98, 99])]
In [5]:
#class select_random_path(object):
# def __call__(self):
# while True:
# self.cnt += 1
# yield XX[np.random.randint(low=0, high=len(XX))]
def select_random_path():
while True:
yield XX[np.random.randint(low=0, high=len(XX))]
In [6]:
class RITNode(object):
def __init__(self, val):
self._val = val
self._children = []
def is_leaf(self):
return len(self._children) == 0
@property
def children(self):
return self._children
def add_child(self, val):
val_intersect = np.intersect1d(self._val, val)
self._children.append(RITNode(val_intersect))
def is_empty(self):
return len(self._val) == 0
def is_leaf(self):
return len(self._children) == 0
@property
def nr_children(self):
return len(self._children) + sum(child.nr_children for child in self._children)
def _traverse_depth_first(self, _idx):
yield _idx[0], self
for child in self.children:
_idx[0] += 1
yield from RITNode._traverse_depth_first(child, _idx=_idx)
class RITTree(RITNode):
def __len__(self):
return self.nr_children + 1
def traverse_depth_first(self):
yield from RITNode._traverse_depth_first(self, _idx=[0])
In [7]:
from functools import partial
def build_tree(feature_paths, max_depth=3, num_splits=5, noisy_split=False, _parent=None, _depth=0):
"""
Parameters
----------
feature_paths : generator of list of ints
...
max_depth : int
The built tree will never be deeper than `max_depth`.
num_splits : int
At each node, the maximum number of children to be added.
"""
expand_tree = partial(build_tree, feature_paths, max_depth=max_depth,
num_splits=num_splits, noisy_split=noisy_split)
if _parent is None:
tree = RITTree(next(feature_paths))
expand_tree(_parent=tree, _depth=0)
return tree
else:
_depth += 1
if _depth >= max_depth:
return
if noisy_split:
num_splits += np.random.randint(low=0, high=2)
for i in range(num_splits):
_parent.add_child(next(feature_paths))
added_node = _parent.children[-1]
if not added_node.is_empty():
expand_tree(_parent=added_node, _depth=_depth)
In [8]:
np.random.seed(12)
tree = build_tree(feature_paths=select_random_path(), max_depth=3, noisy_split=False, num_splits=5)
#path_gen = select_random_path()
#tree = build_tree(feature_paths=path_gen(), max_depth=3, num_splits=5)
In [9]:
#%timeit build_tree(feature_paths=select_random_path())
In [10]:
print("Root:\n", tree._val)
print("Some child:\n", tree.children[0].children[1]._val)
Root:
[ 1 3 4 5 6 7 8 9 10 11 12 14 15 16 17 18 19 20 21 22 23 24 25 26 27
29 30 31 32 35 36 37 39 40 41 42 43 44 46 47 48 49 50 51 52 53 54 56 57 60
61 64 67 68 69 70 71 72 73 74 75 77 78 80 81 82 83 84 85 86 87 88 89 90 91
96 98 99]
Some child:
[ 5 8 9 10 12 14 15 17 19 20 21 23 27 29 30 43 44 46 47 49 50 51 52 56 57
60 61 67 68 69 72 75 80 82 83 84 86 89 90 91]
In [11]:
# If noisy split is False, this should pass
assert(len(tree) == 1 + 5 + 5**2)
In [15]:
# If noisy split is True, this should pass
print(len(tree))
assert(len(tree) <= 1 + 6 + 6**2)
31
In [16]:
list(tree.traverse_depth_first())
Out[16]:
[(0, <__main__.RITTree at 0x10b3c2898>),
(1, <__main__.RITNode at 0x10b3c2e80>),
(2, <__main__.RITNode at 0x10b3f58d0>),
(3, <__main__.RITNode at 0x10b3c2d30>),
(4, <__main__.RITNode at 0x10b3f59b0>),
(5, <__main__.RITNode at 0x10b3c2588>),
(6, <__main__.RITNode at 0x10b3c2908>),
(7, <__main__.RITNode at 0x10b3c2550>),
(8, <__main__.RITNode at 0x10b3c2518>),
(9, <__main__.RITNode at 0x10b3c2a20>),
(10, <__main__.RITNode at 0x10b3c2470>),
(11, <__main__.RITNode at 0x10b3c2ef0>),
(12, <__main__.RITNode at 0x10b3c2390>),
(13, <__main__.RITNode at 0x10b3c28d0>),
(14, <__main__.RITNode at 0x10b3c2b00>),
(15, <__main__.RITNode at 0x10b3c2cc0>),
(16, <__main__.RITNode at 0x10b3c25f8>),
(17, <__main__.RITNode at 0x10b3c2b70>),
(18, <__main__.RITNode at 0x10b3c2b38>),
(19, <__main__.RITNode at 0x10b3c2ac8>),
(20, <__main__.RITNode at 0x10b3c2e10>),
(21, <__main__.RITNode at 0x10b3c2dd8>),
(22, <__main__.RITNode at 0x10b404048>),
(23, <__main__.RITNode at 0x10b404080>),
(24, <__main__.RITNode at 0x10b4040b8>),
(25, <__main__.RITNode at 0x10b4040f0>),
(26, <__main__.RITNode at 0x10b404128>),
(27, <__main__.RITNode at 0x10b404160>),
(28, <__main__.RITNode at 0x10b404198>),
(29, <__main__.RITNode at 0x10b4041d0>),
(30, <__main__.RITNode at 0x10b404208>)]
Content source: Yu-Group/scikit-learn-sandbox
Similar notebooks: