In [1]:
import numpy as np
from scipy.linalg import cho_factor, cho_solve, solve
In [2]:
# Problem size: number of principal components and number of pixels.
ncomp, npix = 5, 2300
In [3]:
# Fix the global NumPy seed so a fresh "Restart & Run All" reproduces the same
# random components (and the global-RNG `ys` draw further down).
np.random.seed(42)
# (npix, ncomp) matrix of standard-normal principal components.
pcomps = np.random.normal(size=(npix, ncomp))
In [4]:
# Identity covariances: one over the components, one over the pixels.
Cgp, Cdata = np.identity(ncomp), np.identity(npix)
In [5]:
# Low-rank term P Cgp P^T; the inner product is formed first, as in the original.
dar = pcomps @ (Cgp @ pcomps.T)
In [6]:
# Full (npix, npix) covariance: low-rank part plus the pixel-noise identity.
C = np.add(dar, Cdata)
In [7]:
# Small synthetic residual vector of length npix.
ys = np.random.normal(size=npix) * 0.01
In [8]:
%%timeit
factor, flag = cho_factor(C, check_finite=False, overwrite_a=True)
np.dot(ys, cho_solve((factor, flag), ys, check_finite=False, overwrite_b=True))
In [11]:
%%timeit
np.dot(ys, solve(C, ys, sym_pos=True, overwrite_a=True, overwrite_b=True, check_finite=False))
In [12]:
# Integers 0..24 as a flat array (reshaped to 5x5 in the next cell).
mat = np.arange(0, 25, 1)
In [13]:
# Use reshape instead of assigning to `mat.shape`: NumPy discourages setting
# ndarray.shape in place. reshape returns a view, and re-running this cell is a no-op.
mat = mat.reshape(5, 5)
In [17]:
# Boolean mask of entries strictly between 7 and 15.
flag = np.logical_and(mat > 7, mat < 15)
In [25]:
# Print every (row, col) index pair where `flag` is True, one pair per line.
hits = np.argwhere(flag)
for k in range(len(hits)):
    print(hits[k])
In [19]:
# Display `mat` to cross-check the index pairs printed above.
mat
Out[19]:
In [20]:
# Toy wavelength grid: 50 evenly spaced points from 10 to 40 (linspace's default count).
wl = np.linspace(10, 40, num=50)
In [27]:
# Pairwise absolute separations: rr[i, j] = |wl[i] - wl[j]|.
rr = np.abs(np.subtract.outer(wl, wl))
In [28]:
# Mask of pairs whose separation is below 7.
flag = np.less(rr, 7)
In [30]:
# (row, col) indices of the in-radius pairs, one pair per row (same as np.argwhere).
np.transpose(np.nonzero(flag))
Out[30]:
In [31]:
%load_ext cythonmagic
In [43]:
%%cython
cimport cython
cimport numpy as np
import numpy as np
# NOTE: `C` here is the module-level Starfish constants module inside this compiled
# cell, unrelated to the notebook's covariance matrix `C`.
from Starfish import constants as C

@cython.boundscheck(False)
@cython.wraparound(False)
def get_dense_C(np.ndarray[np.double_t, ndim=1] wl, k_func, double max_r):
    '''
    Fill out a dense, symmetric covariance matrix from a pairwise kernel.

    :param wl: 1-D float64 array of wavelengths, length N.
    :param k_func: callable ``k_func(wl_i, wl_j) -> float``. It is evaluated only for
        pairs (i, j) with j <= i whose separation is below ``max_r``, and each value
        is mirrored into [j, i].
    :param max_r: cutoff radius in the velocity units of ``C.c_kms`` (presumably
        km/s -- TODO confirm). Pairs at or beyond it are left at 0.
    :returns: (N, N) float64 ndarray.
    '''
    cdef int N = len(wl)
    cdef int i = 0
    cdef int j = 0
    cdef double cov = 0.0
    # Typed buffer so the element writes below are compiled (and boundscheck applies).
    cdef np.ndarray[np.double_t, ndim=2] mat

    # Pairwise separation in velocity space. Dividing the (N, N) array by the
    # length-N `wl` broadcasts along the last axis, so rr[i, j] uses wl[j] as the
    # reference wavelength.
    # NOTE(review): rr is therefore not symmetric in (i, j); only the j <= i entries
    # are consulted below -- confirm that is the intended metric.
    rr = np.abs(wl[:, None] - wl[None, :]) * C.c_kms / wl
    # Keep only the lower triangle (j <= i) before looping. The upper triangle is
    # mirrored, so this halves the Python-level loop without changing the result.
    indices = np.argwhere(np.tril(rr < max_r))

    # The matrix that we want to fill; pairs outside the radius stay 0.
    mat = np.zeros((N, N))
    for index in indices:
        i, j = index
        # Initialize [i, j] and its mirror [j, i].
        cov = k_func(wl[i], wl[j])
        mat[i, j] = cov
        mat[j, i] = cov
    return mat
In [44]:
# Wavelength grid: 2000 evenly spaced points from 5000 to 5100.
wl = np.linspace(5000.0, 5100.0, 2000)
In [45]:
# A named def rather than an assigned lambda (PEP 8, E731), so tracebacks and
# profilers show `k_func` instead of `<lambda>`.
# NOTE: the following cell redefines `k_func`; once it runs, this version is shadowed.
def k_func(wl1, wl2):
    """Squared-exponential kernel exp(-(wl2 - wl1)**2) with unit length scale."""
    return np.exp(-(wl2 - wl1)**2)
In [48]:
def k_func(wl1, wl2, window=(5010., 5020.), window_value=10.):
    """Toy pairwise kernel used to exercise ``get_dense_C``.

    Returns ``window_value`` when ``wl2`` lies strictly inside ``window``
    (``wl1`` is ignored in that case). Otherwise it returns the squared-exponential
    ``exp(-(wl2 - wl1)**2)``.

    Parameters
    ----------
    wl1, wl2 : float
        The two wavelengths being compared.
    window : (float, float), optional
        Open interval ``(lo, hi)`` tested against ``wl2``. Default ``(5010., 5020.)``.
    window_value : float, optional
        Value returned inside the window. Default ``10.``.

    NOTE(review): the window test looks only at ``wl2``, so ``k_func(a, b)`` and
    ``k_func(b, a)`` generally differ. The matrix is not guaranteed to be a valid
    (positive-definite) covariance -- confirm this is only meant for visual checks.
    """
    lo, hi = window
    if lo < wl2 < hi:
        return window_value
    r2 = (wl2 - wl1) ** 2
    return np.exp(-r2)
In [49]:
# Time building the dense matrix for the current `wl` grid and `k_func`, with a
# cutoff radius max_r = 50 (in the velocity units used inside get_dense_C).
%timeit get_dense_C(wl, k_func, 50.)
In [50]:
import matplotlib.pyplot as plt
In [51]:
# Visualize the dense covariance. Uses the explicit Figure/Axes API with a title,
# axis labels and a colorbar so the figure stands on its own when skimmed.
# NOTE: this rebuilds the matrix; the %timeit run above does not keep its result.
cov_dense = get_dense_C(wl, k_func, 50.)
fig, ax = plt.subplots()
im = ax.imshow(cov_dense)
ax.set(title="get_dense_C(wl, k_func, max_r=50)",
       xlabel="column index j", ylabel="row index i")
fig.colorbar(im, ax=ax, label="k_func value")
plt.show()
In [ ]: