In [4]:
# Heap structure: a binary tree in which every child's key is greater than
# or equal to its parent's key (a min-heap), stored in a flat list.
class myHeapArray(object):
    """Array-backed binary min-heap.

    Layout of ``self.heap``: the children of index ``p`` live at ``2*p`` and
    ``2*p + 1`` and the parent of index ``c`` is ``c // 2``.  Because the
    list is 0-based, index 0 (the root) has exactly one real child, index 1;
    the "child" at ``2*0`` is the root itself and never triggers a swap.
    """

    def __init__(self, ):
        # Backing storage; self.heap[0] is always the smallest key.
        self.heap = []

    def bubbleDown(self):
        """Sift the root key down until both of its children are >= it."""
        last = len(self.heap) - 1
        parent = 0
        while 2 * parent <= last:
            # Pick the child holding the smaller key; on a tie (or when the
            # left key is not smaller) the right child is chosen.
            child = 2 * parent
            right = child + 1
            if right <= last and not (self.heap[child] < self.heap[right]):
                child = right
            if self.heap[parent] > self.heap[child]:
                self.heap[parent], self.heap[child] = (
                    self.heap[child],
                    self.heap[parent],
                )
                parent = child
            else:
                # Heap property holds here (also covers child == parent at
                # the root), so nothing below needs fixing.
                break

    def extractMin(self):
        """Remove and return the smallest key, or ``None`` if the heap is empty."""
        if not self.heap:
            return None
        smallest = self.heap[0]
        # Move the last key to the root, shrink the list, then restore order.
        tail = self.heap.pop()
        if self.heap:
            self.heap[0] = tail
            self.bubbleDown()
        return smallest

    def bubbleUp(self):
        """Sift the most recently appended key up towards the root."""
        child = len(self.heap) - 1
        while child != 0:
            # Integer division avoids the float round-trip of int(c / 2).
            parent = child // 2
            if self.heap[parent] > self.heap[child]:
                self.heap[parent], self.heap[child] = (
                    self.heap[child],
                    self.heap[parent],
                )
                child = parent
            else:
                break

    def insert(self, elem):
        """Add ``elem`` to the heap."""
        self.heap.append(elem)
        self.bubbleUp()

    def insertList(self, elemList):
        """Add every element of the iterable ``elemList`` to the heap."""
        for elem in elemList:
            self.insert(elem)

    def get_ordered_list(self):
        """Drain the heap and return its keys as a list in ascending order.

        The heap is empty afterwards.  The loop is driven by the heap's
        length rather than the truthiness of the extracted key, so falsy
        keys such as ``0`` are returned instead of ending extraction early.
        """
        ordered = []
        while self.heap:
            ordered.append(self.extractMin())
        return ordered
In [13]:
# Insertion sort, used as a baseline to compare against the heap sort.
def InsertSort(arr):
    """Return a sorted copy of ``arr`` using insertion sort.

    ``arr`` itself is never modified.  Inputs that provide ``.copy()``
    (lists, NumPy arrays) are copied with it, so the result has the same
    container type; any other iterable (tuple, range, generator) is copied
    into a new list.
    """
    try:
        result = arr.copy()
    except AttributeError:
        # Immutable sequences and plain iterables have no .copy().
        result = list(arr)

    for i in range(1, len(result)):
        val = result[i]
        # First position in the already-sorted prefix whose key is >= val.
        j = next((k for k in range(i) if result[k] >= val), None)
        if j is not None:
            # Shift result[j:i] one slot right and drop val into the gap.
            result[j + 1:i + 1] = result[j:i]
            result[j] = val
    return result
In [98]:
## Comparison: time the heap-based sort against insertion sort on the same
## random input and check both against NumPy's own sort.
from datetime import datetime

import numpy as np

N = 100
arr = np.random.randint(0, 10000, size=N)

s = datetime.now()
h = myHeapArray()
h.insertList(arr)
sortedHS = h.get_ordered_list()
e = datetime.now()
# total_seconds() covers the whole interval; .microseconds is only the
# sub-second component and would under-report runs longer than one second.
print("heapSort(ms)", (e - s).total_seconds() * 1000, "ms")

s = datetime.now()
sortedIS = InsertSort(arr)
e = datetime.now()
print("InsertSort(ms)", (e - s).total_seconds() * 1000, "ms")

arrsort = arr.copy()
arrsort.sort()
# array_equal also reports a mismatch cleanly when the lengths differ.
assert np.array_equal(sortedHS, arrsort), "Don't match! HS {} arr {}, orig {}".format(sortedHS, arrsort, arr)
assert np.array_equal(sortedIS, arrsort), "Don't match! IS {} arr {}, orig {}".format(sortedIS, arrsort, arr)
In [ ]: