In [1]:
import numpy as np
In [2]:
def impl_nth_element(A, l, r, k):
    """Quickselect: partially reorder ``A[l..r]`` (inclusive) in place so that
    ``A[k]`` holds the value it would have if that slice were sorted.

    Afterwards every element of ``A[l..k-1]`` is ``<= A[k]`` and every element
    of ``A[k+1..r]`` is ``>= A[k]``.  Uses Hoare partitioning around a random
    pivot drawn with ``np.random.randint``, so each call advances NumPy's
    global RNG state.

    Parameters
    ----------
    A : list or 1-D NumPy array
        Mutable sequence of mutually comparable items; modified in place.
    l, r : int
        Inclusive bounds of the slice to work on.
    k : int
        Target index, expected to satisfy ``l <= k <= r``.  If it does not,
        the slice may be partitioned once and the function then returns.

    Returns
    -------
    None
    """
    # Loop instead of recursing: after a partition at most one side can
    # contain k, so the original recursion was a tail call, and a loop cannot
    # hit Python's recursion limit on unlucky pivot sequences.
    while l < r:
        i, j = l, r
        # randint's upper bound is exclusive, hence r + 1 so that A[r] can be
        # chosen too.  Indexing with a plain int yields a scalar (not a
        # 1-element array), which also makes this work for Python lists.
        pivot = A[np.random.randint(l, r + 1)]
        while True:
            while A[i] < pivot:
                i += 1
            while A[j] > pivot:
                j -= 1
            if i <= j:
                A[i], A[j] = A[j], A[i]
                i += 1
                j -= 1
            if i > j:
                break
        # Now A[l..j] <= pivot, A[i..r] >= pivot, and anything strictly
        # between j and i equals pivot (already in its final position).
        if l <= k <= j:
            r = j
        elif i <= k <= r:
            l = i
        else:
            return
In [3]:
def nth_element(a, k):
    """Return the k-th smallest element (0-based) of ``a``.

    Like C++ ``std::nth_element``, ``a`` is partially reordered in place by
    ``impl_nth_element``: afterwards ``a[k]`` is the value that a full sort
    would put there, smaller-or-equal values precede it and larger-or-equal
    values follow it.

    Parameters
    ----------
    a : list or 1-D NumPy array
        Mutable sequence of comparable items; modified in place.
    k : int
        0-based rank.  Negative values count from the end, as in Python
        indexing (``-1`` is the maximum).

    Raises
    ------
    IndexError
        If ``k`` is out of range for ``a`` (including when ``a`` is empty).
    """
    n = len(a)
    # Normalize negative ranks up front; previously a negative k was never a
    # target of the partition, so a[k] returned an arbitrary element.
    idx = k + n if k < 0 else k
    if not 0 <= idx < n:
        raise IndexError(f"k={k} out of range for sequence of length {n}")
    impl_nth_element(a, 0, n - 1, idx)
    return a[idx]
def median(a):
    """Return the median of ``a`` via ``nth_element``.

    Picks sorted position ``len(a) // 2``, so for an even-length input this
    is the upper of the two middle values (no averaging).  ``a`` is partially
    reordered in place as a side effect, and an empty ``a`` raises IndexError.
    """
    middle_index = len(a) // 2
    return nth_element(a, middle_index)
In [4]:
import heapq
def heap_select(a, k):
    """Return the k-th smallest element (0-based) of ``a`` using a heap.

    Serves as an independent reference for ``nth_element``.  Unlike that
    function, it only reads ``a`` and leaves its order unchanged.

    Parameters
    ----------
    a : sized iterable of comparable items (e.g. list or 1-D NumPy array)
    k : int
        0-based rank.  Negative values count from the end, as in Python
        indexing (``-1`` is the maximum).

    Raises
    ------
    IndexError
        If ``k`` is out of range for ``a`` (including when ``a`` is empty).
        Previously an oversized ``k`` silently returned the maximum.
    """
    n = len(a)
    idx = k + n if k < 0 else k
    if not 0 <= idx < n:
        raise IndexError(f"k={k} out of range for sequence of length {n}")
    # nsmallest returns its items in ascending order, so the last one is the
    # element of rank idx.
    return heapq.nsmallest(idx + 1, a)[-1]
In [5]:
# Seed NumPy's global RNG: both the demo arrays below and the random pivots in
# impl_nth_element draw from it, so outputs repeat when cells run top to bottom.
np.random.seed(seed=0xDEAD)
In [10]:
# 10 values in [0, 10): many duplicates, which stresses the partition.
xs = np.random.randint(0, 10, size=10)
print(xs)
by_quickselect = nth_element(xs, 2)  # reorders xs in place
by_heap = heap_select(xs, 2)         # read-only reference result
print(by_quickselect)
print(by_heap)
# Check instead of eyeballing: both selectors must match a full sort
# (sorting is unaffected by nth_element's reordering of xs).
assert by_quickselect == by_heap == np.sort(xs)[2]
print(xs)  # shows the partial reordering left by nth_element
In [7]:
# 20 values in [0, 100).
xs = np.random.randint(0, 100, size=20)
print(xs)
by_quickselect = nth_element(xs, 4)  # reorders xs in place
by_heap = heap_select(xs, 4)         # read-only reference result
print(by_quickselect)
print(by_heap)
# Both selectors must match a full sort of the same multiset.
assert by_quickselect == by_heap == np.sort(xs)[4]
print(xs)  # shows the partial reordering left by nth_element
In [13]:
# 25 values in [0, 100), selecting a rank past the middle.
xs = np.random.randint(0, 100, size=25)
print(xs)
by_quickselect = nth_element(xs, 15)  # reorders xs in place
by_heap = heap_select(xs, 15)         # read-only reference result
print(by_quickselect)
print(by_heap)
# Both selectors must match a full sort of the same multiset.
assert by_quickselect == by_heap == np.sort(xs)[15]
print(xs)  # shows the partial reordering left by nth_element
In [9]:
# Edge case: input that is already sorted (0, 1, ..., 9).
xs = np.arange(10)
print(xs)
by_quickselect = nth_element(xs, 2)  # reorders xs in place
by_heap = heap_select(xs, 2)         # read-only reference result
print(by_quickselect)
print(by_heap)
# For sorted distinct input the k-th smallest is simply k.
assert by_quickselect == by_heap == 2
print(xs)