#cython: boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True
cimport numpy as np
import numpy as np
from libc.math cimport sin, cos, exp, sinh, cosh, tan, tanh, log
from scipy.linalg.cython_blas cimport zgemv, zgemm
# BLAS-style scalar constants. zONE/zZERO are passed by pointer as the
# alpha/beta arguments of zgemm; iONE, iZERO and zNHALF are not referenced by
# generated_function but are kept for any other generated code using them.
cdef int iONE=1, iZERO=0
cdef double complex zONE=1, zZERO=0, zNHALF=-0.5
# Declaration of global variables for
# - intermediate results: 2x2 scratch buffers overwritten on every call to
#   generated_function. np.empty returns C-contiguous arrays, which the zgemm
#   calls rely on (C and E are written by zgemm with ldc = number of columns).
# - predefined constants: installed by setup_generated_function.
#
# np.complex128 (== C "double complex") is used instead of the np.complex
# alias, which NumPy deprecated in 1.20 and removed in 1.24.
cdef double complex [:, :] C = np.empty((2, 2), dtype=np.complex128)
cdef double complex [:, :] E = np.empty((2, 2), dtype=np.complex128)
cdef double complex [:, :] F = np.empty((2, 2), dtype=np.complex128)
cdef double complex [:, :] G = np.empty((2, 2), dtype=np.complex128)
cdef double complex [:, :] D  # initialized in setup_generated_function
cdef double complex [:, :] A  # initialized in setup_generated_function
# The generated cython code
cdef int _check_operand(object name, double complex [:, :] m,
                        Py_ssize_t rows, Py_ssize_t cols,
                        bint need_c_contiguous) except -1:
    # Raise ValueError unless `m` has shape (rows, cols) and, when
    # need_c_contiguous is set, a row-major layout with unit inner stride.
    # The module is compiled with boundscheck=False and zgemm receives raw
    # pointers with leading dimension == column count, so an unchecked
    # operand would silently give wrong results or write past the fixed-size
    # scratch buffers.
    cdef Py_ssize_t itemsize = sizeof(double complex)
    if m.shape[0] != rows or m.shape[1] != cols:
        raise ValueError(
            f"{name} must have shape ({rows}, {cols}), "
            f"got ({m.shape[0]}, {m.shape[1]})")
    if need_c_contiguous and (
            (cols > 1 and m.strides[1] != itemsize)
            or (rows > 1 and m.strides[0] != cols * itemsize)):
        raise ValueError(f"{name} must be a C-contiguous array")
    return 0


cdef generated_function(
    # input arguments
    double complex [:, :] B,
    # output argument
    double complex [:, :] H
    ):
    """Write H = -1j * (A.B - B.D) into the caller-provided array H.

    A and D are the constants installed by setup_generated_function.
    C, E, F and G are module-level scratch buffers overwritten on every call.

    Raises ValueError if any operand has the wrong shape, or if A, B or D
    is not C-contiguous.
    """
    # indices and extents for the various matrix operations
    cdef int i, j, ii, jj, kk
    cdef Py_ssize_t rows, cols
    # intermediate results
    global C, E, F, G
    # predefined constants
    global D, A
    # Validate operands against the fixed scratch shape (E has the same shape
    # as C): C = A.B needs A (rows, rows) and B (rows, cols); E = B.D needs
    # D (cols, cols); H receives a (rows, cols) result.
    # NOTE(review): if setup_generated_function was never called, A/D are
    # uninitialized slices that presumably report shape (0, 0), so this turns
    # that case into a ValueError. This relies on Cython zero-initializing
    # module-level memoryviews; confirm.
    rows = C.shape[0]
    cols = C.shape[1]
    _check_operand("A", A, rows, rows, True)
    _check_operand("B", B, rows, cols, True)
    _check_operand("D", D, cols, cols, True)
    _check_operand("H", H, rows, cols, False)
    # evaluating the function
    # BLAS is column-major. A row-major matrix read column-major is its
    # transpose, so X = Y.W is formed as X^T = W^T . Y^T by passing W first.
    # C = A.B
    ii = A.shape[0]
    jj = A.shape[1]
    kk = B.shape[1]
    zgemm('N', 'N', &kk, &ii, &jj, &zONE, &B[0,0], &kk, &A[0,0], &jj, &zZERO, &C[0,0], &kk)
    # E = B.D
    ii = B.shape[0]
    jj = B.shape[1]
    kk = D.shape[1]
    zgemm('N', 'N', &kk, &ii, &jj, &zONE, &D[0,0], &kk, &B[0,0], &jj, &zZERO, &E[0,0], &kk)
    # F = -1*E
    ii = E.shape[0]
    jj = E.shape[1]
    for i in range(ii):
        for j in range(jj):
            F[i,j] = -1*E[i,j]
    # G = C+F
    ii = C.shape[0]
    jj = C.shape[1]
    for i in range(ii):
        for j in range(jj):
            G[i,j] = C[i,j]+F[i,j]
    # H = -1j*G
    ii = G.shape[0]
    jj = G.shape[1]
    for i in range(ii):
        for j in range(jj):
            H[i,j] = -1j*G[i,j]
# Hand the constant matrices from Python over to the Cython module state.
cpdef setup_generated_function(
    double complex [:, :] _D,
    double complex [:, :] _A
    ):
    """Install the constant matrices D and A used by generated_function.

    The arrays are kept as memoryviews, i.e. by reference rather than copied.
    Later in-place changes to the caller's arrays are therefore seen by
    subsequent evaluations.
    """
    global D, A
    # the two assignments are independent; each just rebinds a module global
    A = _A
    D = _D
cpdef np.ndarray[np.complex_t, ndim=2] pythoncall(_0):
    """Python entry point: evaluate generated_function on the matrix `_0`.

    Allocates a fresh 2x2 complex output array on each call, lets
    generated_function fill it, and returns it.
    """
    cdef np.ndarray[np.complex_t, ndim=2] out
    out = np.empty((2, 2), complex)
    generated_function(_0, out)
    return out