In [ ]:
    
%matplotlib inline
    
In [ ]:
    
import numpy as np
from numpy import newaxis as na
import scipy
import scipy.sparse as sps
from scipy.sparse.linalg import spsolve, LinearOperator
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
    
In [ ]:
    
from pyamg.classical import ruge_stuben_solver
from pyfem.sem import SEMhat
from pyfem.topo import Interval
def norm(x):
    """Max-norm of ``x`` (largest absolute entry); 0.0 when ``x`` is empty."""
    if not len(x):
        return 0.0
    return np.max(np.abs(x))
    
In [ ]:
    
# Discretization: polynomial order N on an Ex-by-Ey grid of elements.
N, Ex, Ey = 8, 4, 4
# GLL nodes per element in each direction.
nx = ny = N + 1
n_elem = Ex*Ey
periodic = True  # set to False for Dirichlet boundaries
# With periodic boundaries the last global node in each direction is
# identified with the first, so each direction loses one dof.
nx_dofs = N*Ex + (0 if periodic else 1)
ny_dofs = N*Ey + (0 if periodic else 1)
n_dofs = ny_dofs*nx_dofs
semh = SEMhat(N)
    
In [ ]:
    
def l1(X):
    """Linear Lagrange basis on [-1, 1]: 1 at X = -1, 0 at X = +1."""
    return 0.5*(1.0 - X)
def dl1(X):
    """Derivative of l1: the constant -1/2, shaped like X."""
    ones = np.ones_like(X)
    return -ones/2.0
def l2(X):
    """Linear Lagrange basis on [-1, 1]: 0 at X = -1, 1 at X = +1."""
    return 0.5*(1.0 + X)
def dl2(X):
    """Derivative of l2: the constant +1/2, shaped like X."""
    return 0.5*np.ones_like(X)
    
In [ ]:
    
# def f(X):
#     x = X[:,0]
#     y = X[:,1]
    
#     return np.sin(np.pi*x)*np.sin(np.pi*y)
# def f2(X):
#     x = X[:,0]
#     y = X[:,1]
    
#     return np.sin(np.pi*x)*np.sin(np.pi*y)*2*(np.pi)**2
def f(X):
    """Manufactured exact solution at points X (one (x, y) pair per row).

    u = cos(pi x) cos(pi y), plus the linear term x when the global flag
    ``periodic`` is False. The linear term is harmonic, so it changes only
    the boundary data, not the right-hand side computed by f2.
    """
    x_coord, y_coord = X[:, 0], X[:, 1]
    u = np.cos(np.pi*x_coord)*np.cos(np.pi*y_coord)
    return u if periodic else u + x_coord
def f2(X):
    """Right-hand side -Laplacian(u) = 2 pi^2 cos(pi x) cos(pi y) at points X."""
    cx = np.cos(np.pi*X[:, 0])
    cy = np.cos(np.pi*X[:, 1])
    return cx*cy*2*np.pi**2
    
In [ ]:
    
def ref_to_phys(X, nodes):
    """Bilinear map from reference points X to physical space.

    X[:, 0] is xi and X[:, 1] is eta. ``nodes`` holds the four element
    vertices in the order of the reference corners (-1,-1), (+1,-1),
    (-1,+1), (+1,+1). Returns one mapped point per row of X.
    """
    lo, hi = l1(X), l2(X)
    # Bilinear weight of each vertex, in the same order as `nodes`.
    weights = (lo[:, 0]*lo[:, 1],
               hi[:, 0]*lo[:, 1],
               lo[:, 0]*hi[:, 1],
               hi[:, 0]*hi[:, 1])
    P = weights[0][:, na]*nodes[0, :]
    for k in range(1, 4):
        P = P + weights[k][:, na]*nodes[k, :]
    return P
def calc_jacb(X, nodes):
    """Jacobian of the bilinear reference-to-physical map at points X.

    Returns J with shape (len(X), 2, 2), where
    J[:, i, 0] = d(phys_i)/d(xi) and J[:, i, 1] = d(phys_i)/d(eta),
    with xi = X[:, 0] and eta = X[:, 1]. ``nodes`` lists the vertices for
    the reference corners (-1,-1), (+1,-1), (-1,+1), (+1,+1).
    """
    lo, hi = l1(X), l2(X)
    dlo, dhi = dl1(X), dl2(X)

    def combine(w0, w1, w2, w3):
        # Weighted sum of the four vertices, accumulated in vertex order.
        return (w0[:, na]*nodes[0, :] + w1[:, na]*nodes[1, :] +
                w2[:, na]*nodes[2, :] + w3[:, na]*nodes[3, :])

    # Differentiate the bilinear weights with respect to xi, then eta.
    d_dxi = combine(dlo[:, 0]*lo[:, 1], dhi[:, 0]*lo[:, 1],
                    dlo[:, 0]*hi[:, 1], dhi[:, 0]*hi[:, 1])
    d_deta = combine(lo[:, 0]*dlo[:, 1], hi[:, 0]*dlo[:, 1],
                     lo[:, 0]*dhi[:, 1], hi[:, 0]*dhi[:, 1])

    J = np.zeros((len(X), 2, 2))
    J[:, :, 0] = d_dxi
    J[:, :, 1] = d_deta
    return J
    
In [ ]:
    
# Vertices of an (Ex+1) x (Ey+1) grid on the reference square [-1, 1]^2,
# ordered with the x index varying fastest.
dofx = np.linspace(-1, 1, Ex+1)
dofy = np.linspace(-1, 1, Ey+1)
X, Y = np.meshgrid(dofx, dofy)
vertex_ref = np.column_stack([X.ravel(), Y.ravel()])
# Perturb the vertices with a smooth bump. sin(pi*chi)*sin(pi*eta) is zero
# (to round-off) on the edges of the square, so the domain boundary stays
# put and only interior vertices move.
chi, eta = vertex_ref.T
sx = sy = 0.05
vertex_phys = np.empty_like(vertex_ref)
vertex_phys[:, 0] = chi+sx*np.sin(np.pi*chi)*np.sin(np.pi*eta)
vertex_phys[:, 1] = eta+sy*np.sin(np.pi*chi)*np.sin(np.pi*eta)
# Element-to-vertex connectivity. Element e = ix + iy*Ex lists its vertices
# as (ix, iy), (ix+1, iy), (ix, iy+1), (ix+1, iy+1), i.e. the reference
# corners (-1,-1), (+1,-1), (-1,+1), (+1,+1) used by ref_to_phys/calc_jacb.
# Plain `int` replaces np.int, which NumPy 1.24 removed.
elem_iy, elem_ix = np.divmod(np.arange(n_elem), Ex)
v00 = elem_ix + elem_iy*(Ex+1)
etn = np.column_stack([v00, v00+1, v00+(Ex+1), v00+(Ex+1)+1]).astype(int)
    
In [ ]:
    
# Element-to-dof map: row e lists the global dof of each of the element's
# nx*ny GLL nodes, local node k = iy*nx + ix (x fastest), which matches the
# ordering of the tensor-product quadrature points. The modulo operations
# wrap the last node row/column back onto the first when the mesh is
# periodic; otherwise they are no-ops.
etd = np.zeros((n_elem, nx*ny), dtype=int)  # np.int was removed in NumPy 1.24
rngx = np.arange(nx)
rngy = np.arange(ny)
ind = 0
for iy in range(Ey):
    for ix in range(Ex):
        indy = iy*N
        indx = ix*N
        e = (rngx[na, :]+indx) % nx_dofs + ((rngy[:, na]+indy)*nx_dofs) % n_dofs
        etd[ind, :] = e.ravel()
        ind += 1

# Scatter operator Q (global dofs -> stacked element-local vectors): one row
# per local node, holding a single 1 in the column of its global dof.
cols = etd.ravel()
rows = np.arange(len(cols))
vals = np.ones(len(cols))
# An explicit shape keeps Q at n_dofs columns even if the highest-numbered
# dof were never referenced (an inferred shape would silently shrink).
Q = sps.coo_matrix((vals, (rows, cols)), shape=(len(cols), n_dofs)).tocsr()
    
In [ ]:
    
# Restriction R: selects the free (non-Dirichlet) dofs from a global vector.
# In the periodic case every dof is free and R is the identity.
if periodic:
    R0x = sps.eye(nx_dofs)
    R0y = sps.eye(ny_dofs)
else:
    # Offset-1 diagonal: row i picks dof i+1, dropping the first and last dof.
    R0x = sps.dia_matrix((np.ones(nx_dofs), 1),
                         shape=(nx_dofs-2, nx_dofs))
    R0y = sps.dia_matrix((np.ones(ny_dofs), 1),
                         shape=(ny_dofs-2, ny_dofs))
R = sps.kron(R0y, R0x)
# Boundary dofs are the columns of R that select nothing; this is empty in
# the periodic case. flatnonzero yields a sorted integer index array, which
# replaces the float-valued set difference and the removed np.int dtype.
boundary_dofs = np.flatnonzero(R.T.dot(np.ones(R.shape[0])) == 0)
    
In [ ]:
    
# Tensor-product GLL rule on the reference square: point k = iy*n + ix sits
# at (xgll[ix], xgll[iy]) and has weight wgll[iy]*wgll[ix].
wgll = semh.wgll
xgll = semh.xgll
n = len(xgll)
wv = np.outer(wgll, wgll).ravel()
quad_ref = np.empty((n*n, 2))
quad_ref[:, 0] = np.tile(xgll, n)     # x index varies fastest
quad_ref[:, 1] = np.repeat(xgll, n)   # y index varies slowest
    
In [ ]:
    
# Per-element geometric factors for the stiffness matrix. At each GLL point
# G = w |J| J^{-1} J^{-T}, stored as four diagonal matrices (one per entry of
# the 2x2 tensor) so the element operator is sum_ab D_a^T G_ab D_b.
nn = nx*ny                 # GLL nodes per element
s = (nn, nn)
G11, G12, G21, G22 = [], [], [], []
dof_phys = np.zeros((nx_dofs*ny_dofs, 2))   # physical coordinates of each dof
wvals = np.zeros(nx_dofs*ny_dofs)           # assembled diagonal of w*|J|
for i in range(n_elem):
    ver = vertex_phys[etn[i]]
    J = calc_jacb(quad_ref, ver)
    Ji = np.linalg.inv(J)
    j = np.linalg.det(J).ravel()
    # A non-positive determinant means the vertex perturbation folded or
    # inverted this element, and the discretization would be invalid.
    assert np.all(j > 0), "element %d has a non-positive Jacobian" % i
    # Dofs shared between elements keep the coordinates written by the last
    # element that touches them (for periodic wrap, possibly the image on
    # the opposite side of the domain).
    dof_phys[etd[i], :] = ref_to_phys(quad_ref, ver)
    G0 = np.matmul(Ji, np.transpose(Ji, (0, 2, 1)))
    G0 *= (wv*j)[:, na, na]
    # np.add.at accumulates repeated indices. A plain fancy-indexed `+=`
    # would drop contributions when one element maps several local nodes to
    # the same dof (periodic mesh with Ex == 1 or Ey == 1).
    np.add.at(wvals, etd[i], wv*j)

    G11.append(sps.dia_matrix((G0[:, 0, 0], 0), shape=s))
    G12.append(sps.dia_matrix((G0[:, 0, 1], 0), shape=s))
    G21.append(sps.dia_matrix((G0[:, 1, 0], 0), shape=s))
    G22.append(sps.dia_matrix((G0[:, 1, 1], 0), shape=s))
    
In [ ]:
    
# Assemble the Poisson stiffness matrix A. D1 / D2 differentiate an
# element-local vector along x (the fastest index) and y respectively.
D1 = sps.kron(sps.eye(ny), semh.Dh)
D2 = sps.kron(semh.Dh, sps.eye(nx))
A0a = [D1.T.dot(G11[i].dot(D1) + G12[i].dot(D2)) +
       D2.T.dot(G21[i].dot(D1) + G22[i].dot(D2))
       for i in range(n_elem)]
A0 = sps.block_diag(A0a).tocsr()
# Gather element contributions onto global dofs, then drop constrained dofs.
A0 = Q.T.dot(A0.dot(Q))
A = R.dot(A0.dot(R.T))
# Diagonal GLL mass matrix over all global dofs. Despite the name, Bl is the
# assembled global matrix, not an element-local one.
nd = nx_dofs*ny_dofs
b = wvals
Bl = sps.dia_matrix((b, 0), shape=(nd, nd))
Binv_data = R.dot((1.0/Bl.data).ravel())
# Smallest singular value of A via a dense SVD, so only for small problems.
# In the periodic case it is expected to be ~0, since constants are in the
# null space. Single-argument print(...) works on Python 2 and 3.
if nd <= 1e3:
    print(np.min(np.linalg.svd(A.toarray())[1]))
    
In [ ]:
    
def apply_A(x):
    """Matrix-free product A @ x for a vector x of free dofs.

    Mirrors the assembled A = R Q^T blockdiag(A_e) Q R^T. It prolongs x to
    all global dofs, scatters to element-local vectors, and applies each
    element's sum_ab D_a^T G_ab D_b. It then gathers and restricts the
    result back to the free dofs.
    """
    x = Q.dot(R.T.dot(x)).reshape((n_elem, nx*ny))
    y = np.zeros_like(x)
    # `range` instead of the Python-2-only `xrange`.
    for i in range(n_elem):
        dx = D1.dot(x[i])
        dy = D2.dot(x[i])
        y[i] += D1.T.dot(G11[i].dot(dx) + G12[i].dot(dy))
        y[i] += D2.T.dot(G21[i].dot(dx) + G22[i].dot(dy))
    return R.dot(Q.T.dot(y.ravel()))
# Number of free dofs = number of rows of the restriction R (all dofs when
# periodic, interior dofs otherwise). Deriving it from R keeps it in sync.
nn = R.shape[0]
# Passing dtype spares LinearOperator a probing matvec to infer it.
linOp = LinearOperator((nn, nn), matvec=apply_A, dtype=float)
    
In [ ]:
    
# Right-hand side: mass-weighted source B f2, minus A0 applied to a lifting
# vector that carries the Dirichlet boundary values, restricted to free dofs.
fh = f2(dof_phys)
fl = fh
radj = np.zeros(nx_dofs*ny_dofs)
radj[boundary_dofs] = f(dof_phys)[boundary_dofs]
rhs = R.dot(Bl.dot(fl) - A0.dot(radj))
# The periodic operator annihilates constants, so the rhs must have zero mean.
if periodic:
    rhs = rhs - rhs.mean()
    
In [ ]:
    
# Check the matrix-free apply_A against the assembled matrix A. They should
# agree to round-off, so fail loudly instead of only displaying the number.
apply_A_err = norm(apply_A(rhs) - A.dot(rhs))
assert apply_A_err <= 1e-8*max(1.0, norm(A.dot(rhs))), apply_A_err
apply_A_err
    
In [ ]:
    
# Solve A u = rhs with Ruge-Stuben AMG (CG-accelerated), then prolong to all
# dofs and re-insert the Dirichlet boundary values.
ml = ruge_stuben_solver(A)
residuals = []
sol = R.T.dot(ml.solve(rhs, tol=1e-14, maxiter=1000,
                       residuals=residuals, accel='cg'))
u_exact = f(dof_phys)
sol[boundary_dofs] = u_exact[boundary_dofs]
if periodic:
    # The periodic solution is only determined up to an additive constant.
    sol -= np.mean(sol)
# Single-argument print(...) behaves the same on Python 2 and Python 3.
print("%d %s" % (len(residuals), residuals[-1]))
print("")
print(norm(u_exact - sol)/norm(u_exact))
    
In [ ]:
    
class CB(object):
    """Callable that counts how many times an iterative solver invokes it."""

    def __init__(self):
        # Number of callback invocations so far.
        self.n_iter = 0

    def __call__(self, x):
        # The current iterate `x` is ignored; only the call is counted.
        self.n_iter = self.n_iter + 1
        
# Cross-check: unpreconditioned CG on the matrix-free operator.
cb = CB()
# SciPy 1.12 renamed cg's `tol` to `rtol`, and SciPy 1.14 removed `tol`.
scipy_version = tuple(int(p) for p in scipy.__version__.split(".")[:2])
tol_kw = "rtol" if scipy_version >= (1, 12) else "tol"
cg_opts = {tol_kw: 1e-14, "maxiter": 2000, "callback": cb}
solcg, errc = sps.linalg.cg(linOp, rhs, **cg_opts)
if errc != 0:
    # info > 0: not converged within maxiter; info < 0: illegal input.
    print("CG did not converge (info=%d)" % errc)
solcg = R.T.dot(solcg)
if periodic:
    solcg -= np.mean(solcg)
else:
    solcg[boundary_dofs] = f(dof_phys[boundary_dofs])

print("%d %s" % (cb.n_iter, norm(rhs - apply_A(R.dot(solcg)))))
print("")
print(norm(f(dof_phys) - solcg)/norm(f(dof_phys)))
print(norm(sol - solcg))
    
In [ ]:
    
# Compare the exact solution with the AMG solution (green) on the dof grid.
fig = plt.figure()
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot with
# a 3-D projection works on both old and new versions.
ax = fig.add_subplot(111, projection='3d')
s = (ny_dofs, nx_dofs)
X = dof_phys[:, 0].reshape(s)
Y = dof_phys[:, 1].reshape(s)
ax.plot_wireframe(X, Y, f(dof_phys).reshape(s), label='exact')
ax.plot_wireframe(X, Y, sol.reshape(s), color='g', label='AMG')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
plt.savefig("sol.pdf")
    
In [ ]:
    
# Physical positions of all global dofs (x column, y column of dof_phys).
plt.scatter(*dof_phys.T)
    
In [ ]: