In [ ]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [ ]:
import numpy as np
import scipy.sparse as sps
import matplotlib.pyplot as plt

In [ ]:
import pyopencl as cl
import pyopencl.array
import loopy as lp

In [ ]:
# OpenCL setup: one context and one command queue. Every device array and
# kernel launch in this notebook is issued on this single `queue`.
# NOTE(review): create_some_context may prompt for a platform/device unless
# PYOPENCL_CTX is set -- confirm for non-interactive (Run All / CI) runs.
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)

In [ ]:
def mnorm(x):
    """Return the max norm (largest absolute entry) of array-like ``x``.

    Used below to report the worst elementwise error between a loopy result
    and the explicit sparse Kronecker-product reference.
    """
    return np.max(np.abs(x))


def kron3(A, B, C):
    """Return the sparse Kronecker product kron(A, kron(B, C)).

    Operands may be scipy sparse matrices or dense 2-D arrays; the product is
    built with ``scipy.sparse.kron``, so the result is a sparse matrix.
    """
    return sps.kron(A, sps.kron(B, C))

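3D einsum tests: apply a dense $n\times n$ matrix $D$ along one axis of an $n\times n\times n$ array $A$ with a loopy kernel, and check each result against the explicit sparse Kronecker-product operator.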


In [ ]:
# Problem size: operators act on n**3 unknowns. Each sparse Kronecker operator
# below holds about n**4 nonzeros, so keep n modest.
n = 64
# Seed the global NumPy RNG so D (and A, drawn in the next cell) are
# reproducible on Restart & Run All.
np.random.seed(0)
I = sps.eye(n)
D = np.random.rand(n, n)
# Explicit sparse reference operators acting on the C-order flattening of an
# n x n x n array -- one per tensor axis.
D1 = kron3(I, I, D)  # D contracts the last axis
D2 = kron3(I, D, I)  # D contracts the middle axis
D3 = kron3(D, I, I)  # D contracts the first axis

In [ ]:
# Random input on the n x n x n grid, plus its C-order flattening, which is
# the vector the sparse operators D1/D2/D3 multiply.
A = np.random.rand(n, n, n)
Ar = np.ravel(A)

In [ ]:
# Copy the dense host arrays to the device via `queue` (A first, then D); the
# kernels read these buffers. D goes over as the dense n x n matrix, not as
# one of the sparse operators.
A_d, D_d = [cl.array.to_device(queue, host_arr) for host_arr in (A, D)]

$I\otimes I\otimes D$: contract the last axis of $A$ with $D$, $\mathrm{out}_{ijk}=\sum_a D_{ka}A_{ija}$


In [ ]:
# Loopy kernel for kron(I, I, D) applied to A: contract the last axis of A
# with D, out[i,j,k] = sum_a D[k,a] * A[i,j,a].
knl = lp.make_kernel(
        "{ [i,j,k,a]: 0<=a,i,j,k<n }",
        "out[i,j,k]=sum(a, D[k,a]*A[i,j,a])")
# Ask loopy to print the generated OpenCL code.
knl = lp.set_options(knl, "write_cl")

# Fix the loop-nest order i, j, k (outermost to innermost).
knl = lp.prioritize_loops(knl, "i,j,k")

# NOTE(review): with the split below disabled, no iname carries a g./l. tag,
# so loopy presumably launches a single work-item and the %timeit that
# follows measures a serial device loop -- confirm before comparing timings.
#knl = lp.split_iname(knl, 'i', 16, outer_tag='g.0')

out_d = cl.array.zeros_like(A_d)
evt, _ = knl(queue, D=D_d, A=A_d, out=out_d)
evt.wait()

# Max-norm error against the explicit sparse operator D1; should be at
# round-off level. print(...) is valid on both Python 2 and Python 3
# (the bare `print expr` statement is a SyntaxError on Python 3).
print(mnorm(out_d.get().ravel()-D1.dot(Ar)))

In [ ]:
# Time one launch of the kron(I, I, D) loopy kernel, including the wait for its event.
%timeit knl(queue, D=D_d, A=A_d, out=out_d)[0].wait()

In [ ]:
# NumPy reference. Implicit-mode einsum sums the repeated index j and orders the
# output axes alphabetically (a, i, k): out[a,i,k] = sum_j D[k,j] * A[a,i,j].
%timeit np.einsum('kj,aij', D, A)

In [ ]:
# Sparse reference: the explicit kron(I, I, D) operator applied to the flattened A.
%timeit D1.dot(Ar)

In [ ]:
# Dense reference: ndarray.dot contracts A's last axis with D.T's first,
# giving out[i,j,k] = sum_a A[i,j,a] * D[k,a].
%timeit A.reshape((-1,n,n)).dot(D.T)

$I\otimes D\otimes I$: contract the middle axis of $A$ with $D$, $\mathrm{out}_{ijk}=\sum_a D_{ja}A_{iak}$


In [ ]:
# Loopy kernel for kron(I, D, I) applied to A: contract the middle axis of A
# with D, out[i,j,k] = sum_a D[j,a] * A[i,a,k].
knl = lp.make_kernel(
        "{ [i,j,k,a]: 0<=a,i,j,k<n }",
        "out[i,j,k]=sum(a, D[j,a]*A[i,a,k])")
# Ask loopy to print the generated OpenCL code.
knl = lp.set_options(knl, "write_cl")

# Fix the loop-nest order i, j, k (outermost to innermost).
knl = lp.prioritize_loops(knl, "i,j,k")

out_d = cl.array.zeros_like(A_d)
evt, _ = knl(queue, D=D_d, A=A_d, out=out_d)
evt.wait()

# Max-norm error against the explicit sparse operator D2; should be at
# round-off level. print(...) is valid on both Python 2 and Python 3.
print(mnorm(out_d.get().ravel()-D2.dot(Ar)))

In [ ]:
#%timeit np.einsum('ij,ajk', D, A)

In [ ]:
#%timeit D2.dot(Ar)

$D\otimes I\otimes I$: contract the first axis of $A$ with $D$, $\mathrm{out}_{ijk}=\sum_a D_{ia}A_{ajk}$


In [ ]:
# Loopy kernel for kron(D, I, I) applied to A: contract the first axis of A
# with D, out[i,j,k] = sum_a D[i,a] * A[a,j,k].
knl = lp.make_kernel(
        "{ [i,j,k,a]: 0<=a,i,j,k<n }",
        "out[i,j,k]=sum(a, D[i,a]*A[a,j,k])")
# Ask loopy to print the generated OpenCL code.
knl = lp.set_options(knl, "write_cl")

# Fix the loop-nest order i, j, k (outermost to innermost).
knl = lp.prioritize_loops(knl, "i,j,k")

out_d = cl.array.zeros_like(A_d)
evt, _ = knl(queue, D=D_d, A=A_d, out=out_d)
evt.wait()

# Max-norm error against the explicit sparse operator D3; should be at
# round-off level. print(...) is valid on both Python 2 and Python 3.
print(mnorm(out_d.get().ravel()-D3.dot(Ar)))

In [ ]:
#%timeit np.einsum('ij,jkl', D, A)

In [ ]:
#%timeit D3.dot(Ar)

In [ ]: