import numpy as np
from numba import jit
from gth_solve_jit import gth_solve, gth_solve_jit

@jit('float64[:](float64[:,:])')
def gth_solve_jit2(A):
    """
    Compute the stationary distribution of a square matrix by a
    GTH-style reduction followed by backward substitution.

    Unlike a version that expects a ready-made work array, this variant
    performs the conversion/copy of the input inside the jitted function.

    Parameters
    ----------
    A : array of float64, 2-dimensional
        Square, non-empty matrix. It is copied before being reduced, so
        the caller's data is left untouched.

    Returns
    -------
    x : ndarray of float64, 1-dimensional
        Vector of length ``A.shape[0]`` whose entries sum to one. If the
        reduction stops early at step ``k`` (the remaining entries of row
        ``k`` sum to a non-positive value), only ``x[0..k]`` are
        non-zero.

    Raises
    ------
    ValueError
        If `A` is not 2-dimensional and square, or if it is empty.

    Notes
    -----
    NOTE(review): the timings recorded in this notebook show this variant
    running slower than the other two solvers; presumably the in-function
    ``np.array`` conversion keeps Numba from compiling it in nopython
    mode -- confirm with ``inspect_types()``.
    """
    # Work on a float64 copy: the reduction below overwrites its operand.
    A1 = np.array(A, dtype=np.float64)

    if len(A1.shape) != 2 or A1.shape[0] != A1.shape[1]:
        # No message is attached: the original author noted that
        # ``raise ValueError('matrix must be square')`` was not supported
        # by the Numba version this was written for.
        raise ValueError

    n = A1.shape[0]

    if n == 0:
        # An empty matrix has no solution; without this guard the
        # ``x[n-1] = 1`` below would index an empty array out of bounds.
        raise ValueError

    x = np.zeros(n, dtype=np.float64)

    # === Reduction === #
    for k in range(n-1):
        # Total weight that row k puts on the not-yet-eliminated states.
        scale = np.sum(A1[k, k+1:n])
        if scale <= 0:
            # There is one (and only one) recurrent class contained in
            # {0, ..., k};
            # compute the solution associated with that recurrent class.
            n = k+1
            break
        for i in range(k+1, n):
            A1[i, k] /= scale

            # Fold state k into the remaining block, using the column-k
            # entry that has just been rescaled.
            for j in range(k+1, n):
                A1[i, j] += A1[i, k] * A1[k, j]

    # === Backward substitution === #
    # The last retained state is pinned to 1; earlier entries are
    # accumulated from the later ones via the rescaled column entries.
    x[n-1] = 1
    for k in range(n-2, -1, -1):
        for i in range(k+1, n):
            x[k] += x[i] * A1[i, k]

    # === Normalization === #
    # Entries at index >= n (if the reduction stopped early) are still
    # zero, so summing the whole vector equals summing x[0..n-1].
    norm = np.sum(x)
    for k in range(n):
        x[k] /= norm

    return x

# Print Numba's type-annotation listing for gth_solve_jit2 before it has been
# called (the explicit signature on the decorator requests compilation at
# definition time).
gth_solve_jit2.inspect_types()

# Sanity check on a 2-state chain; the expected result is [0.25, 0.75].
# The matrix is passed as a float64 ndarray rather than a nested list so that
# the argument matches the declared `float64[:,:]` signature.
gth_solve_jit2(np.array([[0.4, 0.6], [0.2, 0.8]]))

array([ 0.25,  0.75])

# Print the type-annotation listing a second time, now that the function has
# been called once, so it can be compared with the listing taken before the
# call.
gth_solve_jit2.inspect_types()

# Matrix dimensions to benchmark. Append 1000 for a longer run.
sizes = [10, 50, 100]
# One random row-stochastic matrix per entry of `sizes`, in the same order.
rand_matrices = []

for n in sizes:
    # Entries are drawn uniformly from [0, 1), then every row is divided by
    # its own sum so that each row adds up to one.
    Q = np.random.rand(n, n)
    Q /= np.sum(Q, axis=1, keepdims=True)
    rand_matrices.append(Q)

for i, Q in enumerate(rand_matrices):
print 'rand_matrices[{0}] ({1} x {2})'.format(i, Q.shape[0], Q.shape[1])
%timeit gth_solve(Q)
%timeit gth_solve_jit(Q)
%timeit gth_solve_jit2(Q)

rand_matrices[0] (10 x 10)
1000 loops, best of 3: 182 µs per loop
100000 loops, best of 3: 5.2 µs per loop
1000 loops, best of 3: 281 µs per loop
rand_matrices[1] (50 x 50)
1000 loops, best of 3: 1.12 ms per loop
10000 loops, best of 3: 63.1 µs per loop
10 loops, best of 3: 21.1 ms per loop
rand_matrices[2] (100 x 100)
100 loops, best of 3: 3 ms per loop
1000 loops, best of 3: 429 µs per loop
10 loops, best of 3: 164 ms per loop

import platform
# Record the operating system / architecture the notebook was run on.
# Parenthesised so the line is valid on both Python 2 and Python 3.
print(platform.platform())

Darwin-13.4.0-x86_64-i386-64bit

``````
import sys
# Record the Python interpreter version and build the notebook was run with.
# Parenthesised so the line is valid on both Python 2 and Python 3.
print(sys.version)

2.7.8 (default, Jul  2 2014, 10:14:46)
[GCC 4.2.1 Compatible Apple LLVM 5.1 (clang-503.0.40)]

# Record the NumPy version the notebook was run with.
# Parenthesised so the line is valid on both Python 2 and Python 3.
print(np.__version__)

1.9.0

``````
import numba
# Record the Numba version the notebook was run with.
# Parenthesised so the line is valid on both Python 2 and Python 3.
print(numba.__version__)

0.15.1

``````
import llvm
# Record the version of the `llvm` module the notebook was run with.
# NOTE(review): presumably this is the llvmpy binding used by this Numba
# release -- confirm; newer Numba releases may not provide a module by this
# name.
# Parenthesised so the line is valid on both Python 2 and Python 3.
print(llvm.__version__)

0.12.7-5-gc0ae9c2

``````
