Union-find 算法


In [2]:
class UF(object):
    def __init__(self, n):
        self.id_ = [x for x in xrange(n)]
        self.count = n
        
    def count(self):
        return self.count
    
    def connected(self, p, q):
        return self.find(p) == self.find(q)
    
    def find(self, p):
        raise ImportError
    
    def union(self, p, q):
        raise ImportError

if __name__ == "__main__":
    uf = UF(15)
    print uf.count
    print uf.id_


15
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

在上面实现了一个基础的 UF 类,因为 UF 的 find 和 union 的实现有很多种,因此我们将 UF 作为其他具体实现方式的基类。


In [21]:
class QuickFindUF(UF):
    def find(self, p):
        return self.id_[p]
    
    def union(self, p, q):
        p_id = self.find(p)
        q_id = self.find(q)
        if p_id == q_id:
            return
        for x in self.id_:
            if self.id_[x] == p_id:
                self.id_[x] = q_id
        self.count -= 1

简要介绍一下quick_find的原理。


In [22]:
uf = QuickFindUF(10)
uf.union(3, 4)
print uf.id_
assert uf.id_[3] == uf.id_[4]
assert uf.connected(3, 4) == True


[0, 1, 2, 4, 4, 5, 6, 7, 8, 9]

id_代表的就是连通分量。节点在id_数组中的值相等即说明是位于同一个连通分量,亦即节点之间是连通的。quick_find的原理就是在连通的时候把其中一个节点的连通分量改成与另一节点一致的值。


In [10]:
class QuickUnionUF(UF):
    def find(self, p):
        if p != self.id_[p]:
            p = self.id_[p]
        return p
    
    def union(self, p, q):
        p_root = self.find(p)
        q_root = self.find(q)
        if p_root == q_root:
            return
        self.id_[p_root] = q_root
        self.count -= 1

In [11]:
uf = QuickUnionUF(10)
uf.union(3, 4)
print uf.id_
assert uf.id_[3] == uf.id_[4]
assert uf.connected(3, 4) == True


[0, 1, 2, 4, 4, 5, 6, 7, 8, 9]