Day 04 - Miller-Rabin primality test


Definition(s)

The Miller-Rabin primality test (or Rabin-Miller primality test) is a probabilistic primality test: an algorithm that determines whether a given number is definitely composite or probably prime.
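The property the test relies on can be summarized as follows, using the same symbols s, d and a as the pseudocode below (this is the standard argument, restated here for reference). For an odd prime n, Fermat's little theorem gives a^(n-1) ≡ 1 (mod n), and modulo a prime the only square roots of 1 are 1 and -1. So, walking down the chain a^(2^s d), a^(2^(s-1) d), ..., a^d by taking square roots, the values either stay at 1 all the way to a^d or pass through -1:

```latex
n - 1 = 2^{s} d, \qquad d \text{ odd}, \qquad 1 \le a \le n - 1

n \text{ prime} \;\Longrightarrow\; a^{d} \equiv 1 \pmod{n}
\quad\text{or}\quad a^{2^{r} d} \equiv -1 \pmod{n} \text{ for some } 0 \le r \le s - 1
```

A base a that satisfies neither condition is a witness: it proves that n is composite.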

The pseudocode, from Wikipedia, is:

Input: n > 2, an odd integer to be tested for primality;
       k, a parameter that determines the accuracy of the test
Output: composite if n is composite, otherwise probably prime
write n − 1 as 2^s·d with d odd by factoring powers of 2 from n − 1
LOOP: repeat k times:
   pick a randomly in the range [2, n − 1]
   x ← a^d mod n
   if x = 1 or x = n − 1 then do next LOOP
   for r = 1 .. s − 1
      x ← x^2 mod n
      if x = 1 then return composite
      if x = n − 1 then do next LOOP
   return composite
return probably prime

For more details, see the Wikipedia article.
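To make one round of the pseudocode concrete, the sketch below traces it for n = 221 = 13 × 17, which is composite. The helper names decompose, round_sequence and passes_round are hypothetical, chosen only for this illustration. Base 174 passes the round even though 221 is composite (it is a "strong liar"), while base 137 exposes 221 as composite:

```python
def decompose(n):
    """Write n - 1 as 2**s * d with d odd; return (s, d)."""
    s, d = 0, n - 1
    while d % 2 == 0:
        d //= 2
        s += 1
    return s, d


def round_sequence(n, a):
    """Values of x in one round: a^d, a^(2d), ..., a^(2^(s-1) d), all mod n."""
    s, d = decompose(n)
    x = pow(a, d, n)
    seq = [x]
    for _ in range(s - 1):
        x = pow(x, 2, n)  # x <- x^2 mod n
        seq.append(x)
    return seq


def passes_round(n, a):
    """True if base a does NOT prove n composite (n odd, n > 2)."""
    seq = round_sequence(n, a)
    return seq[0] == 1 or (n - 1) in seq


n = 221  # 13 * 17, composite
print(decompose(n))                                    # (2, 55)
print(round_sequence(n, 174), passes_round(n, 174))   # [47, 220] True
print(round_sequence(n, 137), passes_round(n, 137))   # [188, 205] False
```

Checking whether n - 1 appears anywhere in the sequence is equivalent to the early exits in the pseudocode: once x becomes 1 without having been n - 1, every later square is also 1, so n - 1 can no longer appear.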

Algorithm(s)


In [22]:
import random  # used to generate random bases

In [2]:
def miller_rabin(n, number_trials=13, use_random_bases=False,
                 bases=(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
                 fast_version=True):
    if n < 2:
        return False

    # special cases 2 and 3
    if n == 2 or n == 3:
        return True

    # even numbers greater than 2 are composite
    if n % 2 == 0:
        return False

    # write n-1 as 2**s * d with d odd
    # by repeatedly dividing n-1 by 2
    s, d = 0, n - 1
    while True:
        quotient, remainder = divmod(d, 2)
        if remainder == 1:
            break

        s += 1
        d = quotient

    assert 2 ** s * d == n - 1

    # test the base a to see whether it is a witness for the compositeness of n
    def slow_witness(a):
        if pow(a, d, n) == 1:
            return False  # possibly prime

        for i in range(s):
            if pow(a, 2 ** i * d, n) == n - 1:
                return False  # possibly prime

        return True  # composite
    
    # same test, but faster: each power is obtained by squaring the previous one
    def fast_witness(a):
        x = pow(a, d, n)

        if x == 1 or x == n - 1:
            return False  # possibly prime

        for _ in range(s - 1):
            x = pow(x, 2, n)

            if x == 1:
                return True  # composite

            if x == n - 1:
                return False  # possibly prime

        return True  # composite

    if fast_version:
        witness = fast_witness
    else:
        witness = slow_witness

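    for i in range(number_trials):
        if use_random_bases:
            a = random.randrange(2, n)  # random base in [2, n - 1]
        else:
            a = bases[i]  # requires number_trials <= len(bases)

        a %= n  # a base equal to (or a multiple of) n says nothing about n
        if a != 0 and witness(a):
            return False  # definitely composite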

    return True  # possibly prime
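The default bases are the first 13 primes, 2 through 41. One base on its own is not enough: 2047 = 23 × 89 is the smallest strong pseudoprime to base 2, that is, the smallest odd composite that passes a round with a = 2. The self-contained sketch below (is_strong_probable_prime is a hypothetical name, not part of the notebook) shows base 2 being fooled and base 3 catching it. With all 13 default bases, no composite below 3317044064679887385961981 passes every round, so in that range the fixed-base mode gives exact answers.

```python
def is_strong_probable_prime(n, a):
    """One Miller-Rabin round: True if odd n > 2 passes for base a."""
    s, d = 0, n - 1
    while d % 2 == 0:  # factor the powers of 2 out of n - 1
        d //= 2
        s += 1
    x = pow(a, d, n)
    if x == 1 or x == n - 1:
        return True
    for _ in range(s - 1):
        x = pow(x, 2, n)
        if x == n - 1:
            return True
    return False  # a is a witness: n is composite


n = 2047  # 23 * 89, the smallest strong pseudoprime to base 2
print(is_strong_probable_prime(n, 2))  # True: base 2 is fooled
print(is_strong_probable_prime(n, 3))  # False: base 3 is a witness
```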

In [14]:
# naive O(sqrt(n)) primality test by trial division
def naive_is_prime(n):
    if n < 2:
        return False

    i = 2
    while i * i <= n:
        if n % i == 0:
            return False  # composite (i is a factor of n)

        i += 1

    return True  # prime
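A common refinement of trial division, shown here as a separate sketch rather than a change to naive_is_prime, skips divisors that are multiples of 2 or 3: every prime above 3 has the form 6k ± 1, so testing only 2, 3 and numbers of that form cuts the work by roughly a factor of three while staying O(sqrt(n)). The name trial_division_6k is hypothetical.

```python
def trial_division_6k(n):
    """Trial division testing only 2, 3 and divisors of the form 6k +/- 1."""
    if n < 2:
        return False
    if n < 4:
        return True  # 2 and 3 are prime
    if n % 2 == 0 or n % 3 == 0:
        return False
    i = 5
    while i * i <= n:
        # i is 6k - 1 and i + 2 is 6k + 1
        if n % i == 0 or n % (i + 2) == 0:
            return False
        i += 6
    return True


print([k for k in range(50) if trial_division_6k(k)])
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
```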

In [34]:
# helper functions (basic choices of options)

def is_prime(n):
    return miller_rabin(n)


def is_prime_slower(n):
    return miller_rabin(n, fast_version=False)

def next_probable_prime(n):
    n += 1
    while not is_prime(n):
        n += 1
        
    return n
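next_probable_prime above also tests every even candidate; each one is rejected quickly, but it still costs a call. A sketch of a variant that steps through odd numbers only follows. The names next_prime_odd_steps and trial_division are hypothetical, and trial_division is just a stand-in predicate so the block runs on its own; any primality test, such as is_prime, could be passed instead.

```python
def next_prime_odd_steps(n, is_prime):
    """Smallest prime greater than n, testing only odd candidates above 2."""
    if n < 2:
        return 2
    candidate = n + 1 if n % 2 == 0 else n + 2  # first odd number > n
    while not is_prime(candidate):
        candidate += 2  # even numbers greater than 2 are never prime
    return candidate


def trial_division(n):
    """Stand-in primality predicate for the usage example."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


print(next_prime_odd_steps(100, trial_division))  # 101
print(next_prime_odd_steps(13, trial_division))   # 17
```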

Run(s)


In [16]:
# basic usage
print(221, is_prime(221))


221 False

In [24]:
# different behaviours for n = 221 = 13 * 17 (composite); base 174 is a strong liar for it
print(221, miller_rabin(221))
print(221, miller_rabin(221, number_trials=1, bases=[174]))
print(221, miller_rabin(221, use_random_bases=True))


221 False
221 True
221 False
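The True result with bases=[174] is a false positive: 174 is a strong liar for 221. Rabin's bound states that for an odd composite n at most a quarter of the bases in [1, n - 1] are strong liars, so 13 independent random rounds, as with use_random_bases=True, wrongly report "probably prime" with probability at most 4^-13 (about 1.5 × 10^-8). The sketch below (strong_liars is a hypothetical helper) lists every strong liar for 221 by brute force:

```python
def strong_liars(n):
    """All bases 1 <= a <= n - 1 for which odd composite n passes one round."""
    s, d = 0, n - 1
    while d % 2 == 0:
        d //= 2
        s += 1
    liars = []
    for a in range(1, n):
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            liars.append(a)
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                liars.append(a)
                break
    return liars


liars = strong_liars(221)
print(len(liars), liars)  # 6 [1, 21, 47, 174, 200, 220]
```

Only 6 of the 220 bases lie, far fewer than the quarter allowed by the bound.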

In [31]:
import timeit

n = 89888786858483
print(n, is_prime(n), timeit.timeit("is_prime(n)", number=1, setup="from __main__ import is_prime, n"))
print(n, naive_is_prime(n), timeit.timeit("naive_is_prime(n)", number=1, setup="from __main__ import naive_is_prime, n"))


89888786858483 True 0.00012834600056521595
89888786858483 True 1.969153579993872
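The gap in these timings comes from the amount of work. Because n = 89888786858483 is prime, trial division runs all the way to sqrt(n), about 9.5 million iterations, whereas each Miller-Rabin round is one modular exponentiation plus at most s - 1 squarings, i.e. O(log n) multiplications. The exponentiation is done by Python's three-argument pow; a minimal sketch of the square-and-multiply idea behind it follows (mod_pow is a hypothetical name, and CPython's own implementation is more heavily optimized):

```python
def mod_pow(base, exponent, modulus):
    """Right-to-left binary exponentiation: computes base**exponent % modulus."""
    result = 1 % modulus
    base %= modulus
    while exponent > 0:
        if exponent & 1:  # lowest bit set: multiply this power in
            result = (result * base) % modulus
        base = (base * base) % modulus  # square for the next bit
        exponent >>= 1
    return result


print(mod_pow(174, 55, 221))  # 47, same as pow(174, 55, 221)
```

The loop runs once per bit of the exponent, which is why the cost grows with log n rather than with sqrt(n).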

In [15]:
# quick correctness check on a single value
is_prime(100) == naive_is_prime(100)


Out[15]:
True

In [25]:
# test for correctness (0 <= i < 10^6)
N = 10 ** 6
[is_prime(i) for i in range(N)] == [naive_is_prime(i) for i in range(N)]


Out[25]:
True
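Most of the time in this check goes into naive_is_prime (about 12 s for 10^6 numbers in this notebook's timings). A sieve of Eratosthenes builds the same reference table in a single pass. The sketch below is an alternative, not what the notebook uses, and the name sieve is hypothetical:

```python
def sieve(limit):
    """Sieve of Eratosthenes: returns flags[i] == True iff i is prime, 0 <= i < limit."""
    flags = [True] * limit
    for i in range(min(limit, 2)):
        flags[i] = False  # 0 and 1 are not prime
    i = 2
    while i * i < limit:
        if flags[i]:
            # cross out multiples of i, starting at i*i (smaller ones are already done)
            for j in range(i * i, limit, i):
                flags[j] = False
        i += 1
    return flags


print([i for i, p in enumerate(sieve(30)) if p])
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

With it, the check could be written as [is_prime(i) for i in range(N)] == sieve(N).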

In [13]:
%%time
# Miller-Rabin with slower witness function
xs = [is_prime_slower(i) for i in range(10 ** 6)]


CPU times: user 11.8 s, sys: 8 ms, total: 11.8 s
Wall time: 11.8 s

In [12]:
%%time
# Miller-Rabin with faster witness function
xs = [is_prime(i) for i in range(10 ** 6)]


CPU times: user 6.51 s, sys: 0 ns, total: 6.51 s
Wall time: 6.51 s

In [11]:
%%time
# naive implementation of a primality test
xs = [naive_is_prime(i) for i in range(10 ** 6)]


CPU times: user 12.1 s, sys: 4 ms, total: 12.1 s
Wall time: 12.1 s

In [36]:
print(next_probable_prime(100))
print(next_probable_prime(2 ** 20))
print(next_probable_prime(2 ** 10))
print(next_probable_prime(13))


101
1048583
1031
17