Brendan Smithyman | September–October 2014
NB: For the comparisons here to be valid, the IPython parallel workers should be on the same node as the notebook server.
In [1]:
# Parallel cluster setup
from IPython.parallel import Client

# Connect to the running IPython parallel cluster.
pclient = Client()

# Number of engines currently registered with the client.
nworkers = len(pclient.ids)

# Two ways of dispatching work: a direct view over every engine, and a
# load-balanced view that hands tasks to whichever engine is free.
dview = pclient[:]
lview = pclient.load_balanced_view()
In [2]:
# Imports synced to worker nodes
with dview.sync_imports():
import numpy as np
import scipy.sparse as sp
import mkl
import SimPEG
from SimPEG import Utils
from zephyr import initHelmholtzNinePoint
# Share the local MKL thread budget evenly between the workers.
nthreads = mkl.get_max_threads()
# Floor division keeps the per-worker count an integer even under true
# division, and the lower bound of 1 avoids asking MKL for zero threads when
# there are more workers than threads.
tpw = max(1, nthreads // nworkers)
dview.apply_sync(mkl.set_num_threads, tpw)
# Local imports
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%pylab inline
In [3]:
import matplotlib
from IPython.display import set_matplotlib_formats

# Inline figures are rendered as PNG.
set_matplotlib_formats('png')

# Change the dpi value to adjust figure size
matplotlib.rc('savefig', dpi=150)

# Plotting options
font = dict(
    family='Bitstream Vera Sans',
    weight='normal',
    size=12,
)
matplotlib.rc('font', **font)
In [4]:
# --- Problem definition ----------------------------------------------------
# numpy names are fully qualified (np.inf, np.fliplr) so this cell does not
# depend on the star import performed by %pylab.
cellSize = 1 # m
freq = 1e2 # Hz
velocity = 2500 # m/s
density = 2700 # units of density
Q = np.inf # can be inf
nx = 256 # count
nz = 256 # count
freeSurf = [True, False, False, False] # t r b l
dims = (nx, nz) # tuple
sLoc = [128, 128] # x, z
nPML = 32

# Velocity model: constant background plus a linear 0-4000 m/s ramp along the
# second axis; fliplr reverses that axis so the values run from
# velocity + 4000 down to velocity.
c = np.fliplr(np.ones(dims) * velocity + np.linspace(0, 4000, nz).reshape((1, nz)))
# Homogeneous density model.
rho = np.fliplr(np.ones(dims) * density)

# ky samples used for the 2.5D runs: nky values from 0 to freq / min(c).
nky = 80
kyrange = np.linspace(0, freq / c.min(), nky)

# Configuration dictionary handed to initHelmholtzNinePoint; the model arrays
# are passed transposed.
systemConfig = {
    'dx': cellSize, # m
    'dz': cellSize, # m
    'freq': freq, # Hz
    'c': c.T, # m/s
    'rho': rho.T, # density
    'Q': Q, # can be inf
    'nx': nx, # count
    'nz': nz, # count
    'sLoc': sLoc, # x, z
    'freeSurf': freeSurf, # t r b l
    'nPML': nPML,
}
In [5]:
def modelfield(systemConfig, ky):
    """Solve one 2D Helmholtz problem for a single ky value.

    Parameters
    ----------
    systemConfig : dict
        Configuration passed to ``initHelmholtzNinePoint``; must contain
        ``'sLoc'``. It is copied, so the caller's dict is not modified.
    ky : scalar
        Value stored under the ``'ky'`` key of the copied configuration.

    Returns
    -------
    u
        Result of applying the factorized system to a source vector that is
        zero everywhere except for a unit entry at the mesh node closest to
        ``sLoc``.
    """
    # Imports live inside the function because it is also shipped to the
    # parallel workers via the views' map(), where it must be self-contained.
    import numpy as np
    import scipy.sparse as sp
    # Load the linalg submodule explicitly instead of relying on another
    # package having imported it as a side effect.
    import scipy.sparse.linalg
    import SimPEG
    from SimPEG import Utils
    from zephyr import initHelmholtzNinePoint

    Solver = SimPEG.SolverWrapD(sp.linalg.splu)

    # Work on a copy so the ky entry never leaks into the caller's dict.
    sc = systemConfig.copy()
    sc['ky'] = ky

    mesh, A = initHelmholtzNinePoint(sc)
    Ainv = Solver(A)

    # Single source
    sLoc = sc['sLoc']
    q = np.zeros(mesh.nN)
    qI = Utils.closestPoints(mesh, sLoc, gridLoc='N')
    q[qI] = 1

    u = Ainv * q
    return u
In [6]:
%%time
mesh, A = initHelmholtzNinePoint(systemConfig)
Solver = SimPEG.SolverWrapD(sp.linalg.splu)
Ainv = Solver(A)
q = np.zeros(mesh.nN)
qI = Utils.closestPoints(mesh, sLoc, gridLoc='N')
q[qI] = 1
u0 = Ainv * q
print('2D Mode: Solving one 2D problem')
print('Serial with {0} MKL threads'.format(nthreads))
In [7]:
%%time
fieldsSer = map(modelfield, [systemConfig.copy() for i in xrange(nky)], kyrange)
uSer = np.array(fieldsSer).sum(axis=0)
print('2.5D Mode: Solving {0} 2D problems to form 3D wavefield'.format(nky))
print('Serial with {0} MKL threads'.format(nthreads))
In [8]:
%%time
res = dview.map(modelfield, [systemConfig.copy() for i in xrange(nky)], kyrange)
uPar = np.array(res.get()).sum(axis=0)
print('2.5D Mode: Solving {0} 2D problems to form 3D wavefield'.format(nky))
print('Parallel direct on {0} workers with {1} MKL threads each'.format(nworkers, tpw))
In [9]:
%%time
res = lview.map(modelfield, [systemConfig.copy() for i in xrange(nky)], kyrange, ordered=False)
uBal = np.array(res.get()).sum(axis=0)
print('2.5D Mode: Solving {0} 2D problems to form 3D wavefield'.format(nky))
print('Parallel balanced on {0} workers with {1} MKL threads each'.format(nworkers, tpw))
In [10]:
# Symmetric colour limits: the 2D panel uses the full amplitude range of u0,
# the 2.5D panels are clipped at 10% of the serial result's peak amplitude.
amax2D = abs(u0).max()
amax3D = abs(uSer).max() * 1e-1

# (title, field, colour limit) for each panel, in subplot order. The last
# panel shows uBal, the load-balanced result, rather than repeating uPar.
panels = [
    ('2D Wavefield', u0, amax2D),
    ('Serial 2.5D', uSer, amax3D),
    ('Parallel Direct 2.5D', uPar, amax3D),
    ('Parallel Balanced 2.5D', uBal, amax3D),
]

fig = plt.figure()

axes = []
images = []
for idx, (label, field, clip) in enumerate(panels, 1):
    ax = plt.subplot(2, 2, idx, aspect=1)
    plt.set_cmap(cm.bwr)
    im = mesh.plotImage(field, vType='N', ax=ax)
    im[0].set_clim(-clip, clip)
    plt.title(label)
    axes.append(ax)
    images.append(im)

# Keep the individual handles available under their original names.
ax1, ax2, ax3, ax4 = axes
im1, im2, im3, im4 = images

fig.tight_layout()
In [10]: