In [ ]:
import numpy as np
#from numba import jit
#@jit
def choBackSubstitution(L,y,lower=True,modify=False):
    """Solve the triangular system ``L x = y`` by column-oriented substitution.

    Parameters
    ----------
    L : ndarray, shape (n, n)
        Triangular coefficient matrix. With ``lower=True`` only the diagonal
        and the entries below it are read; with ``lower=False`` only the
        diagonal and the entries above it are read.
    y : ndarray
        Right-hand side. A 2-D array is treated as a stack of columns that
        are all solved at once; anything else is handled as a single vector.
    lower : bool, optional
        True for forward substitution (lower-triangular ``L``), False for
        back substitution (upper-triangular ``L``).
    modify : bool, optional
        When True the solution is written into ``y`` itself and ``y`` is
        returned. When False ``y`` is left untouched and a new array is
        returned.

    Returns
    -------
    ndarray
        The solution ``x``, with the same shape as ``y``.

    Notes
    -----
    The diagonal of ``L`` is not checked for zeros.
    With ``modify=True`` the arithmetic is done in place in ``y``'s own dtype,
    so an integer ``y`` cannot hold the result (NumPy rejects the in-place
    division). With ``modify=False`` a non-floating ``y`` is promoted to a
    floating dtype so that integer right-hand sides can be solved.
    """
    if modify:
        x = y
    elif np.issubdtype(y.dtype, np.inexact):
        x = np.copy(y)
    else:
        # An integer/bool copy would make the in-place division below fail,
        # so work in a floating dtype wide enough for both L and y.
        x = y.astype(np.result_type(y.dtype, L.dtype, np.float64))

    n = L.shape[0]
    multi_rhs = len(y.shape) == 2
    # Forward substitution walks the columns top-down, back substitution
    # walks them bottom-up.
    order = range(n) if lower else range(n - 1, -1, -1)
    for i in order:
        x[i] /= L[i,i]
        # Rows that are still unsolved: below the pivot for a lower factor,
        # above it for an upper factor.
        rest = slice(i + 1, None) if lower else slice(0, i)
        col = L[rest,i]
        if multi_rhs:
            # (k, 1) * (m,) broadcasts to the outer product, shape (k, m).
            col = col[:,None]
        x[rest] -= col*x[i]
    return x
#@jit
def choSolve(L,b,modify=False):
    """Solve ``(L L^T) x = b`` given the lower factor ``L``.

    Runs a forward substitution with ``L`` followed by a back substitution
    with ``L.T``. ``modify`` is forwarded to the first stage only, so it
    decides whether ``b`` itself is overwritten.
    """
    # Stage 1: L z = b.
    z = choBackSubstitution(L, b, lower=True, modify=modify)
    # Stage 2: L^T x = z. z is an intermediate that only this function still
    # references, so it can always be overwritten in place.
    return choBackSubstitution(L.T, z, lower=False, modify=True)
if __name__ == '__main__':
#from scipy.linalg import cho_solve
from scipy.linalg.lapack import dpotrs
N = 5
y = np.random.uniform(size=N)
Y = np.random.uniform(size=[N,2])
a = np.random.uniform(size=[N,N])
a = a.T.dot(a)
L = np.linalg.cholesky(a)
X = choSolve(L,Y,False)
xa = choSolve(L,Y[:,0],False)
xb = choSolve(L,Y[:,1],False)
assert np.alltrue(np.isclose(X[:,0],xa)),"a fails"
assert np.alltrue(np.isclose(X[:,1],xb)),"b fails"
%load_ext line_profiler
%lprun -f choBackSubstitution choSolve(L,y,False)
#with y vec mod (no copy)
%timeit -n 10 choSolve(L,y,False)
#built in
#%timeit cho_solve((L,True),y)
%timeit -n 10 dpotrs(L,y,1,0)
#x1 = cho_solve((L,True),y)
x1 = dpotrs(L,y,1,0)
x2 = choSolve(L,y,False)
#x1 = dpotrs(L,y,1,1)
print("same:",np.alltrue(np.isclose(x1[0],x2)))
times1 = []
times2 = []
Ns = 10**np.linspace(1,4,20)
from time import clock
for N in Ns:
N = int(N)
y = np.random.uniform(size=N)
a = np.random.uniform(size=[N,N])
a = a.T.dot(a)
L = np.linalg.cholesky(a)
t1 = clock()
#x1 = cho_solve((L,True),y)
x1 = dpotrs(L,y,1,0)
times1.append(clock()-t1)
t1 = clock()
x2 = choSolve(L,y,False)
times2.append(clock()-t1)
import pylab as plt
plt.plot(Ns,times1,label='scipy.linalg.cho_solve')
plt.plot(Ns,times2,label='my choSolve')
plt.yscale('log')
plt.xscale('log')
plt.legend()
plt.show()