Softmax Tree


In [22]:
import h5py
import numpy as np

In [48]:
with h5py.File('data/chorales.hdf5', "r", libver='latest') as f:
    Xtrain = f['Xtrain'].value
    ytrain = f['ytrain'].value
    Xdev = f['Xdev'].value
    ydev = f['ydev'].value
    Xtest = f['Xtest'].value
    ytest = f['ytest'].value

COUNTER = 1
    
# Insert into a tree
def insert(elements, tree, max_level_size, node=1):
    global COUNTER
    el = elements.pop(0)
    if node not in tree:
        tree[node] = [0 for x in range(max_level_size)]
    if tree[node][el - 1] != 0:
        node = tree[node][el - 1]
    else:
        COUNTER += 1
        tree[node][el - 1] = COUNTER
        node = COUNTER
    if len(elements) > 0:
        insert(elements, tree, max_level_size, node)

In [49]:
t = {}
yall = np.vstack((ytrain, ydev, ytest))
mls = max([max(yall[:, i]) for i in range(5)])
COUNTER = 1
for ex in ytrain[:10]:
    if all(x < 5 for x in list(ex)):
        insert(list(ex), t, mls)

In [61]:
def lookup(ex, tree, node=1):
    if node not in tree:
        return False
    if tree[node][ex[0] - 1] == 0:
        return False
    if len(ex) == 1:
        return True
    node = tree[node][ex[0] - 1]
    ex.pop(0)
    lookup(ex, tree, node)

print lookup(list(ytrain[0]), t)


None
[1 1 1 3 4]

In [ ]: