In [ ]:
def get_best_representation(di, A):
    '''
    Greedily build the best representation of di from the atoms in A.

    Starting from the empty representation, every pass tries each atom that
    is not yet selected and shares at least one element with di, scores the
    representation extended by that atom, and keeps the atom giving the
    lowest score.  The search stops as soon as no single atom strictly
    lowers the best score seen so far.

    The score of a representation r is

        paris_distance(di, A, r, r_slack) + len(r) / len(di)

    i.e. the distance term plus a per-atom penalty of 1/len(di).

    di -- set-like collection (must support len() and .intersection()).
    A  -- indexable collection of atoms; representations refer to atoms by
          their index in A.

    Returns the set of selected atom indices (empty when di is empty).

    NOTE: paris_distance, r_slack and verbose are not parameters; they are
    looked up in the enclosing module at call time.
    '''
    # Start with empty set for this representation
    curr_r = set()

    # degenerate case: also avoids dividing by len(di) == 0 below
    if len(di) == 0:
        return curr_r

    # Cost charged for every atom in a representation (loop-invariant)
    atom_penalty = 1.0 / len(di)

    # Score of the empty representation is the baseline to beat
    min_distance = paris_distance(di, A, curr_r, r_slack) + atom_penalty * len(curr_r)

    # Keep adding atoms to the representation until we are unable to improve the result
    while True:
        # Find atom to add to the representation that minimizes total distance
        min_atom_ind = None
        for i in range(len(A)):
            # Only check distance for atoms that are unused and overlap di
            if i in curr_r or not di.intersection(A[i]):
                continue

            # curr_r only holds integer indices, so a shallow union is a
            # full copy; no deepcopy needed
            attempted_r = curr_r | {i}
            dist = paris_distance(di, A, attempted_r, r_slack) + atom_penalty * len(attempted_r)
            if verbose:
                # Single pre-formatted argument prints the same text under
                # both the Python 2 print statement and the print function
                print('Dist, min_dist %s %s' % (dist, min_distance))
            # Strict comparison: on ties the earlier candidate is kept
            if dist < min_distance:
                min_distance = dist
                min_atom_ind = i

        if min_atom_ind is None:
            # No atom improved the score on this pass
            break
        curr_r.add(min_atom_ind)
    return curr_r

def get_global_best_representation(D, A):
    '''
    Compute the best representation of every element of D over the atoms A.

    D -- collection exposing map(func); presumably a Spark RDD of set-like
         lines (TODO confirm against the caller).
    A -- atoms, passed unchanged to get_best_representation for each element.

    Returns whatever D.map returns, with each element di replaced by
    get_best_representation(di, A).
    '''
    def get_individual_rep(di, A=A):
        # Call the solver directly. Returning a lambda here would make
        # D.map produce one function object per element instead of one
        # representation per element.
        return get_best_representation(di, A)
    return D.map(get_individual_rep)

In [ ]:
def PARIS(D, num_iterations=1):
    for iteration in range(num_iterations):