Prim's minimum spanning tree algorithm.

The file prim.txt describes an undirected graph with integer edge costs. It has the format:

[number_of_nodes] [number_of_edges]
[one_node_of_edge_1] [other_node_of_edge_1] [edge_1_cost]
[one_node_of_edge_2] [other_node_of_edge_2] [edge_2_cost]
...

For example, the third line of the file is "2 3 -8874", indicating that there is an edge connecting vertex #2 and vertex #3 that has cost -8874.

You should NOT assume that edge costs are positive, nor should you assume that they are distinct.
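
As a rough illustration of the format described above, the sketch below parses text in that layout into a node count, an edge count, and a list of edges. The helper name `parse_edge_list` and the sample text are hypothetical and not part of the assignment; the sample is made up, and blank lines are skipped on the assumption that they carry no data.

```python
def parse_edge_list(text):
    """Parse '[n] [m]' followed by '[u] [v] [cost]' lines (hypothetical helper)."""
    # Split every non-blank line into whitespace-separated fields.
    lines = [ln.split() for ln in text.splitlines() if ln.strip()]
    n_nodes, n_edges = int(lines[0][0]), int(lines[0][1])
    # int() handles negative costs such as "-8874" directly.
    edges = [(int(u), int(v), int(cost)) for u, v, cost in lines[1:]]
    return n_nodes, n_edges, edges


# Made-up sample in the same layout, with a negative cost and a repeated cost.
sample = """3 3
1 2 7
2 3 -8874
1 3 7
"""
n_nodes, n_edges, edges = parse_edge_list(sample)
print(n_nodes, n_edges, edges)
```

Keeping the costs as plain integers means negative and duplicate values need no special handling later on.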

Your task is to run Prim's minimum spanning tree algorithm on this graph. You should report the overall cost of a minimum spanning tree: an integer, which may or may not be negative.

IMPLEMENTATION NOTES: This graph is small enough that the straightforward O(mn) time implementation of Prim's algorithm should work fine.

OPTIONAL: For those of you seeking an additional challenge, try implementing a heap-based version. The simpler approach, which should already give you a healthy speed-up, is to maintain relevant edges in a heap (with keys = edge costs). The superior approach stores the unprocessed vertices in the heap, as described in lecture. Note this requires a heap that supports deletions, and you'll probably need to maintain some kind of mapping between vertices and their positions in the heap.
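
The simpler heap-based variant mentioned above (edges in a heap, keyed by cost) might look roughly like the following, using Python's standard `heapq` module. This is a minimal sketch under stated assumptions, not the lecture's implementation: the function name `prim_heap_cost` and the `(u, v, cost)` tuple layout are hypothetical, the graph is assumed to be connected, and stale heap entries are skipped when popped rather than deleted from the heap.

```python
import heapq
from collections import defaultdict


def prim_heap_cost(edges, start):
    """Return the total MST cost of a connected undirected graph.

    edges: iterable of (u, v, cost) tuples (assumed layout).
    """
    # Build an adjacency list; each entry is (cost, neighbour) so the heap orders by cost.
    adj = defaultdict(list)
    for u, v, cost in edges:
        adj[u].append((cost, v))
        adj[v].append((cost, u))

    visited = {start}
    heap = list(adj[start])  # candidate edges leaving the tree
    heapq.heapify(heap)
    total = 0

    while heap and len(visited) < len(adj):
        cost, v = heapq.heappop(heap)
        if v in visited:
            continue  # stale entry: this endpoint joined the tree earlier
        visited.add(v)
        total += cost
        for edge in adj[v]:
            if edge[1] not in visited:
                heapq.heappush(heap, edge)

    return total


# Same small example graph as the one in this notebook's test cell.
example = [("1", "2", 1), ("1", "3", 4), ("2", "3", 2), ("2", "4", 6), ("3", "4", 3)]
print(prim_heap_cost(example, "1"))
```

Each edge is pushed at most twice, so the heap work is O(m log m) overall. Skipping stale entries on pop trades a somewhat larger heap for not needing deletions or a vertex-to-position mapping.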


In [35]:
class prim(object):
    MAX_WEIGHT = float("inf")  # sentinel larger than any possible edge cost
    
    def __init__(self, graph, vertices, edges):
        self.graph = graph
        self.vertices = vertices
        self.edges = edges 
        
        self.X = [] # Vertices processed so far
        self.unprocessed_vertices = vertices.copy() # V-X

    def compute_next_min_edge(self):
        """Scan all edges and return the cheapest one crossing from X to V-X."""
        minW = prim.MAX_WEIGHT
        minV1 = None  # endpoint already in X
        minV2 = None  # endpoint still in V-X
        for edge in self.edges:
            node1, node2, weight = edge

            if node1 in self.X and node2 in self.unprocessed_vertices:
                if weight < minW:
                    minW, minV1, minV2 = weight, node1, node2
            elif node2 in self.X and node1 in self.unprocessed_vertices:
                if weight < minW:
                    minW, minV1, minV2 = weight, node2, node1

        # minV1 always belongs to X and minV2 to V-X.
        # Compare with None explicitly so a falsy vertex label such as 0 still counts.
        if minV1 is not None:
            return minV1, minV2, minW
        return None, None, None
        
    
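    def reinit(self, s):
        """Reset the state so that only the start vertex s has been processed."""
        self.X = [s]
        # Copy this instance's own vertices rather than the notebook-level global.
        self.unprocessed_vertices = self.vertices.copy()
        self.unprocessed_vertices.remove(s)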

    def run(self, s):
        """Grow a spanning tree from s and return its edges as [u, v, cost] lists."""
        self.reinit(s)

        MST = []
        while self.unprocessed_vertices:
            minV1, minV2, minW = self.compute_next_min_edge()

            if minV2 is None:
                # No more edges between X and V-X to process. Done.
                break

            # Move the newly reached vertex from V-X into X and record the edge.
            self.unprocessed_vertices.remove(minV2)
            self.X.append(minV2)
            MST.append([minV1, minV2, minW])

        return MST

In [36]:
# Test Prim

#Example
def get_edges(graph):
    edges = []
    for s, adj in graph.items():
        for v in adj:
            edges.append([s, v[0], v[1]])
    return edges
        
graph ={
    "1": [["2",1], ["3", 4]],
    "2": [["3", 2], ["4",6]],
    "3": [["4",3]],
    "4": []
}

vertices = ["1", "3", "2", "4"]
edges = get_edges(graph)

d = prim(graph, vertices, edges)

print (d.run("1"))


[['1', '2', 1], ['2', '3', 2], ['3', '4', 3]]

In [37]:
# Test file
import collections
FILE = "prim_test.txt"

# Read all lines; the with-block closes the file when done
with open(FILE, 'r') as fp:
    data = fp.readlines()

param = data[0].strip().split(' ')
n_nodes, n_edges = int(param[0]), int(param[1])

data = data[1:]

graph = collections.defaultdict(list)
vertices = set()
edges = []

for line in data:
    v = line.strip().split(" ")
    
    vertices.add(v[0])
    vertices.add(v[1])

    graph[v[0]].append(v[1])
    graph[v[1]].append(v[0])

    edges.append([v[0], v[1], int(v[2])])
    
# print ("Vertex 1 adj:", graph["1"])
# print ("First 5 Edges:", edges[:5])

d = prim(graph, vertices, edges)
MST = d.run("1")

weight = 0
for e in MST:
    weight += e[2]
    
print ("Total Weight", weight)

assert weight == 3
print ("Test Passed!")


Total Weight 3
Test Passed!

In [38]:
# Assignment file
import collections
FILE = "prim.txt"

# Read all lines; the with-block closes the file when done
with open(FILE, 'r') as fp:
    data = fp.readlines()

param = data[0].strip().split(' ')
n_nodes, n_edges = int(param[0]), int(param[1])

data = data[1:]

graph = collections.defaultdict(list)
vertices = set()
edges = []

for line in data:
    v = line.strip().split(" ")
    
    vertices.add(v[0])
    vertices.add(v[1])

    graph[v[0]].append(v[1])
    graph[v[1]].append(v[0])

    edges.append([v[0], v[1], int(v[2])])
    
# print ("Vertex 1 adj:", graph["1"])
# print ("First 5 Edges:", edges[:5])

d = prim(graph, vertices, edges)
MST = d.run("1")

weight = 0
for e in MST:
    weight += e[2]
    
print ("Total Weight", weight)


Total Weight -3612829

In [33]:



[['1', '4', 3], ['4', '2', -1], ['2', '3', 2], ['3', '5', -2], ['5', '6', 1]]
