An Object-Oriented Implementation of the Union-Find Algorithm

The class UnionFind maintains two member variables:

  • mParent is a dictionary that maps each node to its parent node. Initially, every node is its own parent, so each node forms a tree of its own.
  • mHeight is a dictionary that stores the heights of the trees. If $x$ is the root of a tree, then $\texttt{mHeight}[x]$ is the height of that tree, counted as the number of nodes on its longest root-to-leaf path. Initially, every tree consists of a single node and therefore has height $1$.

A variant of the algorithm, union by size, would instead maintain a dictionary mSize, where $\texttt{mSize}[x]$ is the number of nodes of the tree rooted at $x$. The implementation below uses union by height and therefore does not need mSize.

In [ ]:
class UnionFind:
    def __init__(self, M):
        # Every element starts as the root of its own one-node tree of height 1.
        self.mParent = { x: x for x in M }
        self.mHeight = { x: 1 for x in M }

Given an element $x$ from the set $M$, the method $\texttt{self}.\texttt{find}(x)$ follows the parent pointers starting at $x$ and returns the root of the tree containing $x$, i.e., the unique ancestor of $x$ that is its own parent.


In [ ]:
def find(self, x):
    p = self.mParent[x]
    if p == x:
        return x
    return self.find(p)

UnionFind.find = find
del find
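The recursive find above climbs from $x$ along the parent pointers until it reaches a node that is its own parent. With union by height, a tree of height $h$ contains at least $2^{h-1}$ nodes, so the recursion depth stays logarithmic in the size of $M$. A common refinement that this notebook does not use is path compression: once the root is known, every node on the path is re-pointed directly at it. The sketch below is a standalone, iterative version that works on a plain parent dictionary shaped like mParent. The name find_compress is hypothetical. Note that if path compression were combined with UnionFind, mHeight would only be an upper bound on a tree's height, often called its rank.

```python
def find_compress(parent, x):
    """Return the root of x's tree and point every node on the path at that root.

    `parent` is a dict mapping each node to its parent (like UnionFind.mParent).
    Hypothetical helper, not part of the UnionFind class.
    """
    # First pass: follow parent pointers until a node is its own parent.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Second pass: re-point every node on the path directly at the root.
    while x != root:
        next_node = parent[x]
        parent[x] = root
        x = next_node
    return root


# A chain 4 -> 3 -> 2 -> 1, where 1 is the root.
parent = {1: 1, 2: 1, 3: 2, 4: 3}
print(find_compress(parent, 4))  # 1
print(parent)                    # {1: 1, 2: 1, 3: 1, 4: 1}
```

After one call, every node on the former chain points straight at the root, so later lookups for these nodes take a single step.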

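Given two elements $x$ and $y$ and an object $o$ of type UnionFind, the call $o.\texttt{union}(x, y)$ changes $o$ so that afterwards the equation $$ o.\texttt{find}(x) = o.\texttt{find}(y) $$ holds. If $x$ and $y$ already have the same root, nothing changes. Otherwise, the root of the lower tree becomes a child of the root of the higher tree, so the height of the combined tree does not grow. Only when both trees have the same height is the height of the new root increased by $1$.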


In [ ]:
def union(self, x, y):
    root_x = self.find(x)
    root_y = self.find(y)
    if root_x != root_y:
        # Attach the root of the lower tree to the root of the higher tree.
        if self.mHeight[root_x] < self.mHeight[root_y]:
            self.mParent[root_x]  = root_y
        elif self.mHeight[root_x] > self.mHeight[root_y]:
            self.mParent[root_y]  = root_x
        else:
            # Equal heights: the combined tree is one level higher.
            self.mParent[root_y]  = root_x
            self.mHeight[root_x] += 1

UnionFind.union = union
del union
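Union by height is one of two common linking rules. The other, union by size, keeps for each root the number of nodes in its tree (stored here in a dictionary mSize) and attaches the smaller tree below the root of the larger one. This also keeps every tree at most logarithmically deep. The class below is a minimal sketch of that variant. The class name UnionFindBySize is hypothetical and not part of this notebook.

```python
class UnionFindBySize:
    """Union-find variant that links by tree size instead of tree height.

    Hypothetical companion to UnionFind: mSize[r] is the number of nodes in
    the tree whose root is r (the value is only kept up to date for roots).
    """

    def __init__(self, M):
        self.mParent = {x: x for x in M}  # every node starts as its own root
        self.mSize = {x: 1 for x in M}    # every tree starts with one node

    def find(self, x):
        # Follow parent pointers until reaching a node that is its own parent.
        while self.mParent[x] != x:
            x = self.mParent[x]
        return x

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return  # already in the same tree
        # Make root_x the root of the larger tree, then hang root_y below it.
        if self.mSize[root_x] < self.mSize[root_y]:
            root_x, root_y = root_y, root_x
        self.mParent[root_y] = root_x
        self.mSize[root_x] += self.mSize[root_y]


uf = UnionFindBySize(range(1, 6))
uf.union(1, 2)
uf.union(3, 4)
uf.union(1, 3)
print(uf.mSize[uf.find(4)])  # 4
```

The choice between the two rules is mostly a matter of taste: union by size has the side benefit that the size of each class is available in constant time once its root is known.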

In [ ]:
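def partition(M, R):
    # Return the equivalence classes of the smallest equivalence relation on M
    # that contains every pair in R, as a list of sets.
    UF = UnionFind(M)
    for x, y in R:
        UF.union(x, y)
    Roots = { x for x in M if UF.find(x) == x }
    return [{y for y in M if UF.find(y) == r} for r in Roots]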
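The function partition scans all of $M$ once for every root, so it makes on the order of $|M| \cdot k$ calls to find, where $k$ is the number of classes. A single pass that files each element under its root is enough. The sketch below shows this with collections.defaultdict. To keep the block self-contained, it inlines a minimal union by height on plain dictionaries instead of using the UnionFind class. The name partition_linear is hypothetical.

```python
from collections import defaultdict


def partition_linear(M, R):
    """Return the equivalence classes generated by the pairs in R as a list of sets.

    Hypothetical single-pass variant of partition: each element's root is
    looked up once and the element is added to that root's group.
    """
    parent = {x: x for x in M}
    height = {x: 1 for x in M}

    def find(x):
        # Follow parent pointers up to the root.
        while parent[x] != x:
            x = parent[x]
        return x

    for x, y in R:
        root_x, root_y = find(x), find(y)
        if root_x == root_y:
            continue
        # Union by height, mirroring UnionFind.union.
        if height[root_x] < height[root_y]:
            parent[root_x] = root_y
        elif height[root_x] > height[root_y]:
            parent[root_y] = root_x
        else:
            parent[root_y] = root_x
            height[root_x] += 1

    # Group every element under its root in one pass over M.
    classes = defaultdict(set)
    for x in M:
        classes[find(x)].add(x)
    return list(classes.values())


P = partition_linear(set(range(1, 10)),
                     {(1, 4), (7, 9), (3, 5), (2, 6), (5, 8), (1, 9), (4, 7)})
print(sorted(sorted(c) for c in P))  # [[1, 4, 7, 9], [2, 6], [3, 5, 8]]
```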

In [ ]:
def demo():
    M = set(range(1, 10))
    R = { (1, 4), (7, 9), (3, 5), (2, 6), (5, 8), (1, 9), (4, 7) }
    P = partition(M, R)
    return P

In [ ]:
P = demo()
P
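For the pairs used in demo, the classes are $\{1, 4, 7, 9\}$, $\{3, 5, 8\}$ and $\{2, 6\}$. The order of the sets in P depends on the iteration order of the set Roots. The small checker below is a hypothetical helper, not part of the notebook. It confirms that a list of sets partitions $M$ and that both elements of each pair in $R$ land in the same block. It does not check that the partition is the finest one with this property, so a coarser partition would also pass.

```python
def is_partition_of(P, M, R):
    """Check that the sets in P partition M and that each pair (x, y) in R lies in one set."""
    # Every block must be non-empty.
    if any(len(block) == 0 for block in P):
        return False
    # The blocks must cover M, and their sizes must add up to |M| (so they are disjoint).
    if sum(len(block) for block in P) != len(M) or set().union(*P) != set(M):
        return False
    # Map each element to the index of its block, then check every pair in R.
    block_of = {x: i for i, block in enumerate(P) for x in block}
    return all(block_of[x] == block_of[y] for x, y in R)


M = set(range(1, 10))
R = {(1, 4), (7, 9), (3, 5), (2, 6), (5, 8), (1, 9), (4, 7)}
print(is_partition_of([{1, 4, 7, 9}, {3, 5, 8}, {2, 6}], M, R))       # True
print(is_partition_of([{1, 4, 7}, {9}, {3, 5, 8}, {2, 6}], M, R))     # False
```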
