In [ ]:
%matplotlib inline
In [ ]:
import numpy as np
from numpy import newaxis as na
import scipy
import scipy.sparse as sps
from scipy.sparse.linalg import spsolve, LinearOperator
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
In [ ]:
from pyamg.classical import ruge_stuben_solver
from pyfem.sem import SEMhat
from pyfem.topo import Interval
def norm(x):
    """Max-norm of ``x`` (largest absolute entry); 0.0 when ``x`` is empty."""
    if len(x) == 0:
        return 0.0
    return np.max(np.abs(x))


def kron3(x, y, z):
    """Kronecker product of three (sparse) matrices: kron(x, kron(y, z)).

    The last factor varies fastest in the resulting index.
    """
    return sps.kron(x, sps.kron(y, z))
In [ ]:
from tensormesh import HexCubePoisson
In [ ]:
# Discretisation parameters.
N = 8          # polynomial order of every spectral element
Ex = 4         # elements in x
Ey = Ex        # elements in y
Ez = Ex        # elements in z
# GLL points per element in each direction.
nx = ny = nz = N + 1
n_elem = Ex*Ey*Ez
periodic = True
# periodic = False
# Global dofs per direction: neighbouring elements share their face points,
# giving N*E + 1 points; with periodic wrap-around the last point is the same
# dof as the first, so there is one fewer.
if periodic:
    nx_dofs, ny_dofs, nz_dofs = N*Ex, N*Ey, N*Ez
else:
    nx_dofs, ny_dofs, nz_dofs = N*Ex + 1, N*Ey + 1, N*Ez + 1
n_dofs = nz_dofs*ny_dofs*nx_dofs
# Only build the sparse stiffness matrix explicitly for small problems.
do_assemble = n_dofs < 1e5
# Reference-element data (GLL nodes/weights, derivative matrix) used below.
semh = SEMhat(N)
In [ ]:
# Linear Lagrange ("hat") functions on the reference interval [-1, 1] and
# their derivatives; all act elementwise on arrays.
def l1(X):
    """Hat function that is 1 at X = -1 and 0 at X = +1."""
    return 0.5*(1.0 - X)


def dl1(X):
    """Derivative of l1: the constant -1/2, shaped like X."""
    return -0.5*np.ones_like(X)


def l2(X):
    """Hat function that is 0 at X = -1 and 1 at X = +1."""
    return 0.5*(1.0 + X)


def dl2(X):
    """Derivative of l2: the constant +1/2, shaped like X."""
    return 0.5*np.ones_like(X)
In [ ]:
# def f(X):
# x = X[:,0]
# y = X[:,1]
# z = X[:,2]
# return np.sin(np.pi*x)*np.sin(np.pi*y)*np.sin(np.pi*z)
# def f2(X):
# x = X[:,0]
# y = X[:,1]
# z = X[:,2]
# return np.sin(np.pi*x)*np.sin(np.pi*y)*np.sin(np.pi*z)*3*(np.pi)**2
def f(X):
    """Manufactured exact solution at the points ``X`` (one (x, y, z) per row).

    cos(2 pi x) cos(2 pi y) cos(2 pi z), plus the linear term ``x`` when the
    global ``periodic`` flag is False (that term is not periodic, and it is
    harmonic, so it does not change the right-hand side f2).
    """
    x, y, z = X[:, 0], X[:, 1], X[:, 2]
    k = 2*np.pi
    r = np.cos(k*x)*np.cos(k*y)*np.cos(k*z)
    if not periodic:
        r += x
    return r


def f2(X):
    """Right-hand side -Laplacian(f) at the points ``X`` (one (x, y, z) per row).

    For u = cos(k x) cos(k y) cos(k z), -Laplacian(u) = 3 k^2 u with k = 2 pi.
    """
    x, y, z = X[:, 0], X[:, 1], X[:, 2]
    k = 2*np.pi
    return np.cos(k*x)*np.cos(k*y)*np.cos(k*z)*3*(k)**2
In [ ]:
def ref_to_phys(X, nodes):
    """Map reference points to physical space with the trilinear hex map.

    Parameters
    ----------
    X : (n, 3) array
        Reference coordinates, one point per row.
    nodes : (8, d) array
        Physical vertex coordinates (d = 3 in this notebook). Vertex k sits at
        the reference corner whose x/y/z side is given by bits 0/1/2 of k
        (bit 0 -> -1, bit 1 -> +1).

    Returns
    -------
    (n, d) array
        sum_k N_k(X) * nodes[k], where N_k is the product of the per-axis hats.
    """
    lo_x, lo_y, lo_z = l1(X).T
    hi_x, hi_y, hi_z = l2(X).T
    # hat[side][axis]: side 0 -> l1 factor, side 1 -> l2 factor.
    hat = ((lo_x, lo_y, lo_z), (hi_x, hi_y, hi_z))

    def weight(k):
        # Trilinear shape function of vertex k at every point.
        bx, by, bz = k & 1, (k >> 1) & 1, (k >> 2) & 1
        return hat[bx][0]*hat[by][1]*hat[bz][2]

    P = weight(0)[:, na]*nodes[0, :]
    for k in range(1, 8):
        P = P + weight(k)[:, na]*nodes[k, :]
    return P
def calc_jacb(X, nodes):
    """Jacobian of the trilinear reference-to-physical map at points ``X``.

    Parameters
    ----------
    X : (n, 3) array
        Reference coordinates, one point per row.
    nodes : (8, 3) array
        Physical vertex coordinates. Vertex k sits at the reference corner
        whose x/y/z side is given by bits 0/1/2 of k (bit 0 -> -1, bit 1 -> +1).

    Returns
    -------
    J : (n, 3, 3) array
        J[q, i, j] = d(physical coordinate i) / d(reference coordinate j)
        at point q.
    """
    lo_x, lo_y, lo_z = l1(X).T
    hi_x, hi_y, hi_z = l2(X).T
    dlo_x, dlo_y, dlo_z = dl1(X).T
    dhi_x, dhi_y, dhi_z = dl2(X).T
    # table[side][axis]: side 0 -> l1 family, side 1 -> l2 family.
    value = ((lo_x, lo_y, lo_z), (hi_x, hi_y, hi_z))
    slope = ((dlo_x, dlo_y, dlo_z), (dhi_x, dhi_y, dhi_z))
    corners = [(k & 1, (k >> 1) & 1, (k >> 2) & 1) for k in range(8)]

    J = np.zeros((len(X), 3, 3))
    for axis in range(3):
        # Column `axis` of J: sum over vertices of dN_k/d(ref_axis) * nodes[k].
        column = None
        for k, side in enumerate(corners):
            fac = [value[side[a]][a] for a in range(3)]
            fac[axis] = slope[side[axis]][axis]
            term = (fac[0]*fac[1]*fac[2])[:, na]*nodes[k, :]
            column = term if column is None else column + term
        J[:, :, axis] = column
    return J
In [ ]:
# Element vertices: an (Ex+1) x (Ey+1) x (Ez+1) lattice on the reference cube
# [-1, 1]^3, flattened with x fastest and z slowest.
dofx = np.linspace(-1, 1, Ex+1)
dofy = np.linspace(-1, 1, Ey+1)
dofz = np.linspace(-1, 1, Ez+1)
XYZ = np.zeros((Ez+1, Ey+1, Ex+1, 3))
XYZ[:, :, :, 0] = dofx[na, na, :]
XYZ[:, :, :, 1] = dofy[na, :, na]
XYZ[:, :, :, 2] = dofz[:, na, na]
vertex_ref = XYZ.reshape((-1, 3))

# Physical vertices: shift every vertex by a bump that vanishes (up to
# rounding) whenever any coordinate is +-1. The cube faces therefore stay put,
# which preserves periodicity, while interior elements are distorted.
chi, eta, zeta = vertex_ref.T
sx = sy = sz = 0.1
vertex_phys = vertex_ref.copy()
vp = vertex_phys
vp[:, 0] = chi + sx*np.sin(np.pi*chi)*np.sin(np.pi*eta)*np.sin(np.pi*zeta)
vp[:, 1] = eta + sy*np.sin(np.pi*chi)*np.sin(np.pi*eta)*np.sin(np.pi*zeta)
vp[:, 2] = zeta + sz*np.sin(np.pi*chi)*np.sin(np.pi*eta)*np.sin(np.pi*zeta)

# Element-to-vertex connectivity. Elements are numbered with ix fastest and iz
# slowest. Local vertex k of an element is offset from its first vertex by
# (k & 1, (k >> 1) & 1, (k >> 2) & 1) lattice steps in (x, y, z); this is the
# vertex order ref_to_phys/calc_jacb expect.
# Plain `int` replaces the `np.int` alias, which NumPy 1.24 removed.
stride_y = Ex + 1
stride_z = (Ex + 1)*(Ey + 1)
first_vertex = (np.arange(Ex)[na, na, :]
                + np.arange(Ey)[na, :, na]*stride_y
                + np.arange(Ez)[:, na, na]*stride_z).ravel()
corner_offsets = np.array([0, 1, stride_y, stride_y + 1,
                           stride_z, stride_z + 1,
                           stride_z + stride_y, stride_z + stride_y + 1])
etn = (first_vertex[:, na] + corner_offsets[na, :]).astype(int)
In [ ]:
# Restriction operator R: maps a global dof vector onto the unknowns that are
# actually solved for.
#   periodic:     every dof is an unknown (R is the identity);
#   non-periodic: drop the first and last point in each direction (Dirichlet
#                 boundary), keeping interior dofs only.
if periodic:
    R0x = sps.eye(nx_dofs)
    R0y = sps.eye(ny_dofs)
    R0z = sps.eye(nz_dofs)
else:
    # Row i holds a single 1 in column i+1, selecting dofs 1 .. n-2.
    R0x = sps.dia_matrix((np.ones(nx_dofs), 1),
                         shape=(nx_dofs-2, nx_dofs))
    R0y = sps.dia_matrix((np.ones(ny_dofs), 1),
                         shape=(ny_dofs-2, ny_dofs))
    R0z = sps.dia_matrix((np.ones(nz_dofs), 1),
                         shape=(nz_dofs-2, nz_dofs))
R = kron3(R0z, R0y, R0x)
# Boundary dofs are the global dofs that no row of R selects. Computing this
# with an exact 0/1 coverage count (rather than set arithmetic on the float
# values of R.dot(arange)) keeps the indices integral. It also works for both
# cases: in the periodic case it yields an empty, already sorted index array
# without the removed `np.int` alias.
covered = R.T.dot(np.ones(R.shape[0]))
boundary_dofs = np.flatnonzero(covered == 0)
In [ ]:
# Element-to-dof map: etd[e] lists the global dof index of each of the
# (N+1)^3 GLL points of element e (x fastest, z slowest). Neighbouring
# elements share their face points. In the periodic case the modulo
# operations wrap the last point in each direction onto the first.
# (Plain `int` replaces the removed `np.int` alias.)
etd = np.zeros((n_elem, nx*ny*nz), dtype=int)
rngx = np.arange(nx)
rngy = np.arange(ny)
rngz = np.arange(nz)
nxy_dofs = nx_dofs*ny_dofs
ind = 0
for iz in range(Ez):
    for iy in range(Ey):
        for ix in range(Ex):
            indz = iz*N
            indy = iy*N
            indx = ix*N
            e = ((rngx[na, na, :] + indx) % nx_dofs
                 + ((rngy[na, :, na] + indy)*nx_dofs) % nxy_dofs
                 + (rngz[:, na, na] + indz)*nxy_dofs)
            e = e % n_dofs
            etd[ind, :] = e.ravel()
            ind += 1

# Gather operator Q: row r picks global dof etd.ravel()[r], so Q.dot(u) yields
# the element-local copies of u (element after element). Q.T sums element
# contributions back onto the shared global dofs.
cols = etd.ravel()
rows = np.arange(len(cols))
vals = np.ones(len(cols))
# Give the shape explicitly: inferring it from max(cols) would silently
# produce too few columns if the highest-numbered dof were never referenced.
Q = sps.coo_matrix((vals, (rows, cols)),
                   shape=(len(cols), n_dofs)).tocsr()
In [ ]:
# Tensor-product GLL quadrature on the reference cube, in the same point
# order as the element dofs (x fastest, z slowest).
xgll = semh.xgll
wgll = semh.wgll
n = len(xgll)
# Weight of point (iz, iy, ix) is w[iz]*w[iy]*w[ix].
wv = (wgll[:, na, na]*wgll[na, :, na]*wgll[na, na, :]).ravel()
# Reference coordinates (xi, eta, zeta) of every quadrature point, one per row.
zeta_q, eta_q, xi_q = np.meshgrid(xgll, xgll, xgll, indexing='ij')
quad_ref = np.stack([xi_q, eta_q, zeta_q], axis=-1).reshape((-1, 3))
In [ ]:
# Per-element geometric factors for the stiffness operator.
# At each quadrature point of element i, with J = d(phys)/d(ref):
#     G0 = w * det(J) * inv(J) * inv(J)^T
# Gab[i] is the diagonal matrix of entry (a, b) of G0 (a, b = 1..3).
# The same loop records the physical coordinates of the dofs and accumulates
# the lumped-mass weights w * det(J) per global dof.
G11, G12, G13 = [], [], []
G21, G22, G23 = [], [], []
G31, G32, G33 = [], [], []
G_rows = ((G11, G12, G13), (G21, G22, G23), (G31, G32, G33))
nn = nx*ny*nz
s = (nn, nn)
dof_phys = np.zeros((nx_dofs*ny_dofs*nz_dofs, 3))
wvals = np.zeros(nx_dofs*ny_dofs*nz_dofs)
for i in range(n_elem):
    ver = vertex_phys[etn[i]]
    J = calc_jacb(quad_ref, ver)
    Ji = np.linalg.inv(J)
    j = np.linalg.det(J).ravel()
    # A non-positive determinant means an inverted or degenerate element;
    # the resulting operator would be meaningless, so fail loudly.
    if np.any(j <= 0):
        raise ValueError("element %d has a non-positive Jacobian determinant "
                         "(inverted or degenerate element)" % i)
    # Shared dofs are overwritten by every element that touches them. With
    # periodic wrap-around, the last element in a direction writes its +1
    # face coordinates onto the dofs it shares with the -1 face.
    dof_phys[etd[i], :] = ref_to_phys(quad_ref, ver)
    G0 = np.matmul(Ji, np.transpose(Ji, (0, 2, 1)))
    G0 *= (wv*j)[:, na, na]
    # np.add.at accumulates repeated indices. A fancy-indexed "+=" keeps only
    # one contribution per index if etd[i] lists a dof twice, e.g. a periodic
    # mesh with a single element per direction.
    np.add.at(wvals, etd[i], wv*j)
    for ia in range(3):
        for ib in range(3):
            G_rows[ia][ib].append(sps.dia_matrix((G0[:, ia, ib], 0), shape=s))
In [ ]:
# Reference-element derivative matrices acting on element-local vectors
# (x fastest): D1 = d/dxi, D2 = d/deta, D3 = d/dzeta.
D1 = kron3(sps.eye(nz), sps.eye(ny), semh.Dh)
D2 = kron3(sps.eye(nz), semh.Dh, sps.eye(nx))
D3 = kron3(semh.Dh, sps.eye(ny), sps.eye(nx))
if do_assemble:
    # Element stiffness matrices: sum_a D_a^T (sum_b G_ab D_b).
    A0a = [D1.T.dot(G11[i].dot(D1) + G12[i].dot(D2) + G13[i].dot(D3))
           + D2.T.dot(G21[i].dot(D1) + G22[i].dot(D2) + G23[i].dot(D3))
           + D3.T.dot(G31[i].dot(D1) + G32[i].dot(D2) + G33[i].dot(D3))
           for i in range(n_elem)]
    # Direct stiffness summation (Q^T . Q), then restriction to the unknowns.
    A0 = Q.T.dot(sps.block_diag(A0a).tocsr().dot(Q))
    A = R.dot(A0.dot(R.T))

# Lumped (diagonal) global mass matrix built from the accumulated weights in
# wvals. The name Bl is historical: this is the assembled global matrix, not
# an element-local one.
nd = nx_dofs*ny_dofs*nz_dofs
b = wvals
Bl = sps.dia_matrix((b, 0), shape=(nd, nd))
# Inverse mass diagonal restricted to the solved-for dofs.
Binv_data = R.dot((1.0/Bl.data).ravel())
In [ ]:
def apply_A(x, apply_R=True, apply_Q=True):
    """Matrix-free application of the Poisson stiffness operator.

    Computes element by element the same product as the assembled
    A = R Q^T blockdiag(A_e) Q R^T, without forming any matrix.

    Parameters
    ----------
    x : array
        Input vector. With ``apply_R=True`` it lives in the restricted
        (solved-for) space. With ``apply_R=False`` it is a full global dof
        vector. If ``apply_Q`` is also False, it must already be
        element-local, i.e. reshapeable to (n_elem, -1).
    apply_R, apply_Q : bool
        Whether to apply R^T / Q on the way in, and Q^T / R on the way out.

    Returns
    -------
    1-D array in the same space as ``x``.
    """
    if apply_R:
        x = R.T.dot(x)
    if apply_Q:
        x = Q.dot(x)
    x = x.reshape((n_elem, -1))
    y = np.zeros_like(x)
    # range (not the Python-2-only xrange) so this also runs on Python 3.
    for i in range(n_elem):
        # Reference-space derivatives of the local field.
        d1 = D1.dot(x[i])
        d2 = D2.dot(x[i])
        d3 = D3.dot(x[i])
        # y_e = sum_a D_a^T (sum_b G_ab d_b): grouped as in the assembled A,
        # needing 3 transposed derivative applications instead of 9.
        y[i] = (D1.T.dot(G11[i].dot(d1) + G12[i].dot(d2) + G13[i].dot(d3))
                + D2.T.dot(G21[i].dot(d1) + G22[i].dot(d2) + G23[i].dot(d3))
                + D3.T.dot(G31[i].dot(d1) + G32[i].dot(d2) + G33[i].dot(d3)))
    y = y.ravel()
    if apply_Q:
        y = Q.T.dot(y)
    if apply_R:
        y = R.dot(y)
    return y
# Size of the solved-for system: every dof when periodic, otherwise only the
# interior dofs kept by R.
nn = n_dofs if periodic else (nz_dofs-2)*(ny_dofs-2)*(nx_dofs-2)
linOp = LinearOperator((nn, nn), matvec=apply_A)
# Preconditioner: tensormesh's HexCubePoisson solver. It receives no vertex
# data, so it presumably solves on an undeformed cube.
# NOTE(review): only Ex is passed, so this assumes Ex == Ey == Ez.
M = HexCubePoisson(N, Ex, L=2, periodic=periodic)
M.build_mesh()
precond = LinearOperator((nn, nn), matvec=M.solve)
In [ ]:
# Exact solution and weak-form load vector. f2 = -Laplacian(f); multiplying
# by the lumped mass matrix gives the discrete right-hand side.
exact = f(dof_phys)
fh = f2(dof_phys)
fl = fh
rhs = Bl.dot(fl)
# Dirichlet lifting: move the known boundary values to the right-hand side,
# then restrict to the unknowns.
radj = np.zeros(nx_dofs*ny_dofs*nz_dofs)
radj[boundary_dofs] = exact[boundary_dofs]
rhs = R.dot(rhs - apply_A(radj, apply_R=False))
if periodic:
    # Constants lie in the null space of the periodic stiffness operator, so
    # make the rhs mean-free for the system to be consistent.
    rhs -= np.mean(rhs)
In [ ]:
# Sanity check: the matrix-free apply_A must agree with the assembled A
# (the printed max-norm difference should be at rounding level).
# print(...) with one argument behaves identically on Python 2 and 3.
if do_assemble:
    print(norm(apply_A(rhs) - A.dot(rhs)))
In [ ]:
# Reference solve with the assembled matrix: Ruge-Stuben AMG accelerated by CG.
# Prints are written so they produce the same text on Python 2 and run on
# Python 3.
if do_assemble:
    ml = ruge_stuben_solver(A)
    residuals = []
    sol = R.T.dot(ml.solve(rhs, tol=1e-14,
                           maxiter=500, residuals=residuals,
                           accel='cg'))
    # Re-insert the Dirichlet boundary values removed by R.
    sol[boundary_dofs] = f(dof_phys)[boundary_dofs]
    if periodic:
        # Periodic solutions are unique only up to a constant: pin dof 0 to 0.
        sol -= sol[0]
        exact -= exact[0]
    print("%d %s" % (len(residuals), residuals[-1]))
    print("")
    print(norm(exact - sol)/norm(exact))
In [ ]:
class CB(object):
    """Iteration counter for use as a scipy iterative-solver callback.

    Every call increments ``n_iter``; the iterate passed in is ignored.
    """

    def __init__(self):
        self.n_iter = 0

    def __call__(self, x):
        self.n_iter = self.n_iter + 1
# Matrix-free preconditioned CG solve of the same system.
# (For an unpreconditioned comparison, drop the M=precond argument.)
cb = CB()
try:
    # SciPy >= 1.12 names the relative tolerance `rtol`; `tol` was removed in 1.14.
    solcg, errc = sps.linalg.cg(linOp, rhs, rtol=1e-14,
                                maxiter=2000, callback=cb,
                                M=precond)
except TypeError:
    # Older SciPy (including the Python 2 builds) only accepts `tol`.
    cb = CB()
    solcg, errc = sps.linalg.cg(linOp, rhs, tol=1e-14,
                                maxiter=2000, callback=cb,
                                M=precond)
# cg's info flag: 0 = converged, > 0 = hit maxiter without reaching the
# tolerance, < 0 = illegal input or breakdown. Report it instead of silently
# comparing an unconverged iterate.
if errc != 0:
    print("CG did not converge (info = %d)" % errc)
solcg = R.T.dot(solcg)
if periodic:
    # Periodic solutions are unique only up to a constant: pin dof 0 to 0.
    solcg -= solcg[0]
    exact -= exact[0]
else:
    solcg[boundary_dofs] = f(dof_phys[boundary_dofs])
print("%d %s" % (cb.n_iter, norm(rhs - apply_A(R.dot(solcg)))))
print("")
print(norm(exact - solcg)/norm(exact))
if do_assemble:
    print(norm(sol - solcg))
In [ ]:
# Physical dof coordinates and CG solution on the (z, y, x) grid, for plotting.
dp = dof_phys.reshape((nz_dofs, ny_dofs, nx_dofs, 3))
if periodic:
    # Wrapped dofs hold the +1 face coordinate written by the last element in
    # each direction; map those back to -1 so the grid plots correctly.
    # np.where returns a new array, so dof_phys is left untouched.
    dp = np.where(dp == 1.0, -1.0, dp)
ds = solcg.reshape((nz_dofs, ny_dofs, nx_dofs))
In [ ]:
# Exact solution (default colour) vs. CG solution (green) on z-slice k.
fig = plt.figure()
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot works
# on both old and new versions.
ax = fig.add_subplot(111, projection='3d')
s = (ny_dofs, nx_dofs)
k = int(0.4*nz_dofs)
X = dp[k, :, :, 0]
Y = dp[k, :, :, 1]
ax.plot_wireframe(X, Y, exact.reshape((nz_dofs, ny_dofs, nx_dofs))[k, :, :])
ax.plot_wireframe(X, Y, ds[k, :, :].reshape(s),
                  color='g')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('z-slice %d: exact vs. CG solution (green)' % k)
plt.show()
In [ ]:
# Physical (x, y) positions of the dofs on the plotted z-slice.
plt.scatter(x=X, y=Y)
In [ ]: