In [ ]:
%matplotlib ipympl

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numba
import numpy as np
import scipy.ndimage
import scipy.signal

# Define computational precision
# Each numba/numpy pair must describe the same type: the numba types are used
# in the jitclass specs, the numpy dtypes when allocating arrays assigned to them.
DTYPE_NUMBA_REAL = numba.float32
DTYPE_NUMPY_REAL = np.float32
# Unsigned 16-bit integers hold time-step and grid indices (it, rx_ix, rx_iz,
# src_ix, src_iz); NOTE(review): grid sizes and nt must stay below 65536.
DTYPE_NUMBA_UINT = numba.uint16
DTYPE_NUMPY_UINT = np.uint16

In [ ]:
@numba.jitclass(
    [
        ('xmin',    DTYPE_NUMBA_REAL),
        ('xmax',    DTYPE_NUMBA_REAL),
        ('zmin',    DTYPE_NUMBA_REAL),
        ('zmax',    DTYPE_NUMBA_REAL),
        ('tmin',    DTYPE_NUMBA_REAL),
        ('tmax',    DTYPE_NUMBA_REAL),
        ('vmin',    DTYPE_NUMBA_REAL),
        ('vmax',    DTYPE_NUMBA_REAL),
        ('fmax',    DTYPE_NUMBA_REAL),
        ('cfl',     DTYPE_NUMBA_REAL),
        # Time axis, filled by init_src_tf()
        ('tt',      DTYPE_NUMBA_REAL[:]),
        # Velocity model on the (x, z) grid
        ('vv',      DTYPE_NUMBA_REAL[:,:]),
        # Source time function sampled on tt
        ('src_tf',  DTYPE_NUMBA_REAL[:]),
        ('_src_loc', DTYPE_NUMBA_REAL[:]),
        # Wavefield at three time levels: [..., 0] previous, [..., 1] current, [..., 2] next
        ('psi',     DTYPE_NUMBA_REAL[:,:,:]),
        # Scratch arrays for the second spatial derivatives along x and z
        ('d2px',    DTYPE_NUMBA_REAL[:,:]),
        ('d2pz',    DTYPE_NUMBA_REAL[:,:]),
        # Current time-step index
        ('it',      DTYPE_NUMBA_UINT),
        # Receiver grid indices along x and z
        ('rx_ix',   DTYPE_NUMBA_UINT[:]),
        ('rx_iz',   DTYPE_NUMBA_UINT[:]),
        # Receiver traces: one row per receiver, one column per time sample
        ('ss',      DTYPE_NUMBA_REAL[:,:]),
    ]
)
class FDSolver2DWaveEQ_DEP(object):
    """Deprecated all-numba 2-D finite-difference wave solver.

    Usage: assign xmin/xmax, zmin/zmax, tmin/tmax, vmin/vmax and fmax, then
    call init_grid() and init_src_tf(); set rx_ix/rx_iz and call
    init_receivers(); finally advance with update(nstep).

    NOTE(review): newer numba releases expose jitclass as
    ``numba.experimental.jitclass``; confirm the installed numba still
    provides ``numba.jitclass``.
    """
    def __init__(self):
        self.cfl      = 0.5
        self.it       = 0
        self._src_loc = np.zeros(2, dtype=DTYPE_NUMPY_REAL)

    @property
    def dt(self):
        """Time step from the CFL condition: cfl * dx / vmax."""
        return (self.cfl * self.dx / self.vmax)
    @property
    def dx(self):
        """Grid spacing along x: 20 points per minimum wavelength vmin / fmax."""
        return (self.vmin / self.fmax / 20)
    
    @property
    def dz(self):
        """Grid spacing along z (same rule as dx)."""
        return (self.vmin / self.fmax / 20)
    
    @property
    def nt(self):
        """Number of time samples spanning [tmin, tmax]."""
        return (round((self.tmax - self.tmin) / self.dt))
    
    @property
    def nx(self):
        """Number of grid points along x."""
        return (round((self.xmax - self.xmin) / self.dx))

    @property
    def nz(self):
        """Number of grid points along z."""
        return (round((self.zmax - self.zmin) / self.dz))
    
    @property
    def src_ix(self):
        """Grid index along x nearest to the source location."""
        return (round((self.src_loc[0] - self.xmin) / self.dx))
    
    @property
    def src_iz(self):
        """Grid index along z nearest to the source location."""
        return (round((self.src_loc[1] - self.zmin) / self.dz))

    @property
    def src_loc(self):
        """Source position (x, z) in model coordinates."""
        return (self._src_loc)

    @src_loc.setter
    def src_loc(self, value):
        self._src_loc[0] = value[0]
        self._src_loc[1] = value[1]

    @property
    def xx(self):
        """Grid coordinates along x."""
        return (np.linspace(self.xmin, self.xmax, self.nx).astype(DTYPE_NUMPY_REAL))

    @property
    def zz(self):
        """Grid coordinates along z."""
        return (np.linspace(self.zmin, self.zmax, self.nz).astype(DTYPE_NUMPY_REAL))


    def init_grid(self):
        """Allocate the wavefield and derivative arrays and a constant-vmin velocity model."""
        self.psi     = np.zeros((self.nx, self.nz, 3), dtype=DTYPE_NUMPY_REAL)
        self.d2px    = np.zeros((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
        self.d2pz    = np.zeros((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
        self.vv      = self.vmin * np.ones((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
    
    def init_src_tf(self):
        """Build the time axis ``tt`` and the source time function.

        The source is the first derivative of a Gaussian centred at
        t0 = 4 / fmax, normalised to unit peak amplitude.  ``tt`` is stored
        on the instance (it was previously a discarded local, leaving the
        declared ``tt`` field unset).
        """
        self.tt      = np.linspace(self.tmin, self.tmax, self.nt).astype(DTYPE_NUMPY_REAL)
        t0           = 4 / self.fmax
        self.src_tf  = (
            -2.*(self.tt-t0) * (self.fmax**2) * (np.exp(-1.0*(self.fmax**2) * (self.tt-t0)**2))
        ).astype(DTYPE_NUMPY_REAL)
        self.src_tf /= np.max(np.abs(self.src_tf))
        
    def init_receivers(self):
        """Allocate zeroed receiver traces; rx_ix must already be set."""
        self.ss = np.zeros((len(self.rx_ix), self.nt), dtype=DTYPE_NUMPY_REAL)

    def update(self, nstep=1):
        """Advance the wavefield ``nstep`` time steps."""
        self.psi = _update(self, nstep)

@numba.njit(parallel=True)
def _update(solver, nstep):
    """Advance a FDSolver2DWaveEQ_DEP instance ``nstep`` time steps in place.

    Each step records the current wavefield at the receivers into
    ``solver.ss``, increments ``solver.it``, applies a second-order
    finite-difference update and injects the source.  Samples past the end
    of ``ss`` or ``src_tf`` are skipped instead of indexing out of bounds
    (numba does not bounds-check by default).

    Returns
    -------
    psi : the wavefield array held by ``solver.psi`` (updated in place).
    """
    psi  = solver.psi
    d2px = solver.d2px
    d2pz = solver.d2pz
    ss   = solver.ss
    n_rec = solver.rx_ix.shape[0]
    for i in range(nstep):
        # Record receivers at the current step (replaces the former debug prints).
        if solver.it < ss.shape[1]:
            for j in range(n_rec):
                ss[j, solver.it] = psi[solver.rx_ix[j], solver.rx_iz[j], 1]
        solver.it += 1
        d2px[1:-1]                            = (psi[2:, :, 1] - 2 * psi[1:-1, :, 1] + psi[:-2, :, 1]) / solver.dx ** 2
        d2pz[:, 1:-1]                         = (psi[:, 2:, 1] - 2 * psi[:, 1:-1, 1] + psi[:, :-2, 1]) / solver.dz ** 2
        psi[:, :, 2]                          = 2 * psi[:, :, 1] - psi[:, :, 0] + solver.vv ** 2 * solver.dt ** 2 * (d2px + d2pz)
        # Inject the source only while its time function has samples left.
        if solver.it < solver.src_tf.shape[0]:
            psi[solver.src_ix, solver.src_iz, 2] += solver.src_tf[solver.it] / (solver.dx * solver.dz) * solver.dt ** 2
        # Shift time levels: previous <- current, current <- next.
        psi[:, :, 0], psi[:, :, 1]            = psi[:, :, 1], psi[:, :, 2]
    return (psi)

In [ ]:
class FDSolver2DWaveEQ(object):
    """Python front end for the 2-D finite-difference wave solver.

    The caller assigns the configuration attributes ``xmin``, ``xmax``,
    ``zmin``, ``zmax``, ``tmin``, ``tmax``, ``vmin``, ``vmax`` and ``fmax``
    (plus ``src_loc``, ``rx_ix``, ``rx_iz`` and optionally ``vv``), then calls
    ``init()`` to allocate the arrays and copy everything into the numba core
    (``Solva``), and ``update(nstep)`` to time-step.
    """

    def __init__(self):
        self.cfl      = 0.5    # Courant number used to derive dt
        self.it       = 0      # time-step index, refreshed from the numba core by update()
        self._src_loc = np.zeros(2, dtype=DTYPE_NUMPY_REAL)
        # The velocity model is cached here instead of probing the jitclass:
        # hasattr() on a jitclass field is not a reliable "was it set?" test.
        self._vv      = None
        self._solva   = Solva()

    @property
    def dt(self):
        """Time step from the CFL condition: ``cfl * dx / vmax``."""
        return (self.cfl * self.dx / self.vmax)

    @property
    def dx(self):
        """Grid spacing along x: 20 points per minimum wavelength ``vmin / fmax``."""
        return (self.vmin / self.fmax / 20)
    
    @property
    def dz(self):
        """Grid spacing along z (same rule as ``dx``)."""
        return (self.vmin / self.fmax / 20)
    
    @property
    def nt(self):
        """Number of time samples spanning ``[tmin, tmax]``."""
        return (round((self.tmax - self.tmin) / self.dt))
    
    @property
    def nx(self):
        """Number of grid points along x."""
        return (round((self.xmax - self.xmin) / self.dx))

    @property
    def nz(self):
        """Number of grid points along z."""
        return (round((self.zmax - self.zmin) / self.dz))
    
    @property
    def psi(self):
        """Wavefield held by the numba core (shape ``(nx, nz, 3)`` after ``init()``)."""
        return (self._solva.psi)
    
    @property
    def rx_ix(self):
        """Receiver grid indices along x."""
        return (self._rx_ix)
    
    @rx_ix.setter
    def rx_ix(self, value):
        self._rx_ix = np.array(value, dtype=DTYPE_NUMPY_UINT)

    @property
    def rx_iz(self):
        """Receiver grid indices along z."""
        return (self._rx_iz)
    
    @rx_iz.setter
    def rx_iz(self, value):
        self._rx_iz = np.array(value, dtype=DTYPE_NUMPY_UINT)
    
    @property
    def src_ix(self):
        """Grid index along x nearest to the source location."""
        return (round((self.src_loc[0] - self.xmin) / self.dx))
    
    @property
    def src_iz(self):
        """Grid index along z nearest to the source location."""
        return (round((self.src_loc[1] - self.zmin) / self.dz))

    @property
    def src_loc(self):
        """Source position ``(x, z)`` in model coordinates."""
        return (self._src_loc)

    @src_loc.setter
    def src_loc(self, value):
        self._src_loc = np.array(value, dtype=DTYPE_NUMPY_REAL)

    @property
    def src_tf(self):
        """Source time function held by the numba core."""
        return (self._solva.src_tf)
    
    @property
    def ss(self):
        """Receiver traces held by the numba core (NaN where not yet recorded)."""
        return (self._solva.ss)

    @property
    def vv(self):
        """Velocity model, shape ``(nx, nz)``; defaults to a constant ``vmin`` model."""
        if getattr(self, '_vv', None) is None:
            # float32 to match the ``vv`` field type of the numba core.
            self._vv = np.full((self.nx, self.nz), self.vmin, dtype=DTYPE_NUMPY_REAL)
        return (self._vv)
    
    @vv.setter
    def vv(self, value):
        self._vv = np.array(value, dtype=DTYPE_NUMPY_REAL)
        self._solva.vv = self._vv

    @property
    def xx(self):
        """Grid coordinates along x."""
        return (np.linspace(self.xmin, self.xmax, self.nx).astype(DTYPE_NUMPY_REAL))

    @property
    def zz(self):
        """Grid coordinates along z."""
        return (np.linspace(self.zmin, self.zmax, self.nz).astype(DTYPE_NUMPY_REAL))
    
    def init(self):
        """Allocate the simulation arrays and load them into the numba core.

        Restarts the step counter at 0, builds the time axis ``tt``, zeroes
        the wavefield, fills the receiver traces with NaN (unrecorded), and
        computes the source time function: the first derivative of a
        Gaussian centred at ``t0 = 4 / fmax``.  ``rx_ix``/``rx_iz`` must be
        set beforehand.
        """
        self.it = 0
        self.tt = np.linspace(self.tmin, self.tmax, self.nt).astype(DTYPE_NUMPY_REAL)
        t0      = 4 / self.fmax
        core    = self._solva
        core.psi    = np.zeros((self.nx, self.nz, 3), dtype=DTYPE_NUMPY_REAL)
        core.d2px   = np.zeros((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
        core.d2pz   = np.zeros((self.nx, self.nz), dtype=DTYPE_NUMPY_REAL)
        core.it     = self.it
        core.dt     = self.dt
        core.dx     = self.dx
        core.dz     = self.dz
        core.rx_ix  = self.rx_ix
        core.rx_iz  = self.rx_iz
        core.src_ix = self.src_ix
        core.src_iz = self.src_iz
        core.ss     = np.full((len(self.rx_ix), self.nt), fill_value=np.nan, dtype=DTYPE_NUMPY_REAL)
        core.vv     = self.vv
        core.src_tf = (
            -2.*(self.tt-t0) * (self.fmax**2) * (np.exp(-1.0*(self.fmax**2) * (self.tt-t0)**2))
        ).astype(
            DTYPE_NUMPY_REAL
        )
        
    def update(self, nstep=1):
        """Advance the simulation ``nstep`` time steps.

        The numba core owns the step counter; it is copied back into
        ``self.it`` so Python-side code (e.g. ``tmin + it * dt`` time labels)
        sees the current step.
        """
        self._solva.update(nstep)
        self.it = int(self._solva.it)
        
@numba.jitclass(
    [
        # Wavefield at three time levels: [..., 0] previous, [..., 1] current, [..., 2] next
        ('psi',    DTYPE_NUMBA_REAL[:,:,:]),
        # Scratch arrays for the second spatial derivatives along x and z
        ('d2px',   DTYPE_NUMBA_REAL[:,:]),
        ('d2pz',   DTYPE_NUMBA_REAL[:,:]),
        # Receiver traces: one row per receiver, one column per time sample
        ('ss',     DTYPE_NUMBA_REAL[:,:]),
        # Current time-step index
        ('it',     DTYPE_NUMBA_UINT),
        ('dt',     DTYPE_NUMBA_REAL),
        ('dx',     DTYPE_NUMBA_REAL),
        ('dz',     DTYPE_NUMBA_REAL),
        # Receiver grid indices along x and z
        ('rx_ix',  DTYPE_NUMBA_UINT[:]),
        ('rx_iz',  DTYPE_NUMBA_UINT[:]),
        # Source grid indices
        ('src_ix', DTYPE_NUMBA_UINT),
        ('src_iz', DTYPE_NUMBA_UINT),
        # Source time function, indexed by time step
        ('src_tf', DTYPE_NUMBA_REAL[:]),
        # Velocity model on the (x, z) grid
        ('vv',     DTYPE_NUMBA_REAL[:,:])
    ]
)
class Solva(object):
    """Numba-compiled state container for FDSolver2DWaveEQ.

    All fields are assigned from Python by ``FDSolver2DWaveEQ.init()``;
    ``update`` advances the simulation via ``_update_psi``.
    NOTE(review): newer numba releases expose jitclass as
    ``numba.experimental.jitclass`` — confirm the installed version.
    """
    def __init__(self):
        # Fields are populated later from Python; nothing to initialise here.
        pass
    
    def update(self, nstep=1):
        """Advance the wavefield and receiver traces ``nstep`` time steps."""
        self.psi, self.ss = _update_psi(self, nstep)
    
@numba.njit(parallel=True)
def _update_psi(solva, nstep):
    """Advance a ``Solva`` state ``nstep`` time steps in place.

    Each step records the current wavefield at the receivers into ``ss``,
    increments ``it``, applies a second-order finite-difference update and
    injects the source.  Once ``it`` runs past the end of ``ss`` or
    ``src_tf``, recording and injection are skipped, but the wavefield keeps
    propagating.  This prevents out-of-bounds access, which numba does not
    check by default.

    Returns
    -------
    (psi, ss) : the arrays held by ``solva`` (updated in place).
    """
    psi  = solva.psi
    d2px = solva.d2px
    d2pz = solva.d2pz
    ss   = solva.ss
    n_rec    = solva.rx_ix.shape[0]
    n_trace  = ss.shape[1]
    n_source = solva.src_tf.shape[0]
    for i in range(nstep):
        if solva.it < n_trace:
            for j in range(n_rec):
                ss[j, solva.it] = psi[solva.rx_ix[j], solva.rx_iz[j], 1]
        solva.it += 1
        d2px[1:-1]                          = (psi[2:, :, 1] - 2 * psi[1:-1, :, 1] + psi[:-2, :, 1]) / solva.dx ** 2
        d2pz[:, 1:-1]                       = (psi[:, 2:, 1] - 2 * psi[:, 1:-1, 1] + psi[:, :-2, 1]) / solva.dz ** 2
        psi[:, :, 2]                        = 2 * psi[:, :, 1] - psi[:, :, 0] + solva.vv ** 2 * solva.dt ** 2 * (d2px + d2pz)
        if solva.it < n_source:
            psi[solva.src_ix, solva.src_iz, 2] += solva.src_tf[solva.it] / (solva.dx * solva.dz) * solva.dt ** 2
        # Shift time levels: previous <- current, current <- next.
        psi[:, :, 0], psi[:, :, 1]          = psi[:, :, 1], psi[:, :, 2]
    return (psi, ss)

In [ ]:
# Model and acquisition set-up for the FDSolver2DWaveEQ front end.
solver = FDSolver2DWaveEQ()
solver.xmin, solver.xmax = 0, 20000
solver.zmin, solver.zmax = 0, 10000
solver.tmin, solver.tmax = 0, 3
solver.vmin, solver.vmax = 3500, 5500
solver.fmax              = 20
solver.src_loc           = 8000, 8000

# Grid indices of the model centre, and a receiver-spacing unit of nx/20 points.
x0 = round(solver.nx/2)
z0 = round(solver.nz/2)   # previously undefined: receivers sit at mid-depth
dx = round(solver.nx/20)
vv = 4500 * np.ones((solver.nx, solver.nz))
# Optional heterogeneous band around x0 (uncomment to use):
# vv[x0-dx:x0+dx] += 2500 * np.random.randint(-1, 2, (2*dx, solver.nz)) * np.random.rand(2*dx, solver.nz)
solver.vv = scipy.ndimage.gaussian_filter(vv, sigma=2)

solver.rx_ix = x0-2*dx, x0, x0+2*dx
solver.rx_iz = z0, z0, z0

solver.init()

In [ ]:
# Propagate the wavefield for the first 1000 time steps.
solver.update(nstep=1000)

In [ ]:
# Wavefield snapshot over the velocity model (top) and the receiver traces
# recorded so far (bottom).
# Colour scale: peak source amplitude as injected into a single grid cell.
amax = solver.src_tf.max() / (solver.dx*solver.dz) * solver.dt**2
# Physical-coordinate grids; not used in this cell, kept as notebook-level names.
xx, zz = np.meshgrid(solver.xx, solver.zz)
plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(2, 1, 1, aspect=1)
ax.imshow(solver.vv.T, cmap=plt.get_cmap('gray'))
qmesh = ax.imshow(
    solver.psi[..., 1].T,
    vmin=-amax,
    vmax=amax,
    cmap=plt.get_cmap('seismic'),
    alpha=0.25
)
ax.scatter(solver.rx_ix, solver.rx_iz, marker='v', color='k')
ax.set_xlabel('x index')
ax.set_ylabel('z index')

ax = fig.add_subplot(2, 1, 2)
for ir in range(solver.ss.shape[0]):
    trace = solver.ss[ir]
    recorded = trace[~np.isnan(trace)]   # samples not yet simulated are NaN
    peak = np.max(np.abs(recorded)) if recorded.size else 0.0
    # Scale each trace to half a unit and offset it to its receiver number;
    # silent or unrecorded traces are left unscaled to avoid 0/0.
    if peak > 0:
        trace = trace / (2 * peak)
    ax.plot(ir + trace, solver.tt, 'k')
ax.set_xlabel('receiver')
ax.set_ylabel('time (s)')
ax.set_ylim(solver.tt.min(), solver.tt.max())
ax.invert_yaxis()

In [ ]:
def update_qmesh(idx, solver, ax, qmesh, time_text=None, nstep=67):
    """Advance ``solver`` and refresh the wavefield image for one animation frame.

    Callback for ``matplotlib.animation.FuncAnimation``.

    Parameters
    ----------
    idx : int
        Frame number supplied by FuncAnimation (unused).
    solver : FDSolver2DWaveEQ
        Initialised solver; advanced in place by ``nstep`` steps.
    ax : matplotlib.axes.Axes
        Axes holding the image; only used to create a label when
        ``time_text`` is None.
    qmesh : matplotlib.image.AxesImage
        Image artist displaying ``psi[..., 1].T``.
    time_text : matplotlib.text.Text, optional
        Reusable elapsed-time label.  If None, a new label is drawn each
        call (the former behaviour).
    nstep : int
        Time steps per frame.

    Returns
    -------
    tuple of Artist
        The artists modified this frame, as required by ``blit=True``.
    """
    solver.update(nstep)
    qmesh.set_data(solver.psi[..., 1].T)
    # The numba core owns the authoritative step counter.
    label = f'{solver.tmin + solver._solva.it * solver.dt: 06.3f} s'
    if time_text is None:
        time_text = ax.text(0.05, 0.95, label, ha='left', va='top',
                            transform=ax.transAxes, bbox=dict(facecolor='w'))
    else:
        time_text.set_text(label)
    return (qmesh, time_text)


ANIM_PATH = 'anim.gif'   # relative output path instead of a user-specific absolute one
# Colour scale: peak source amplitude as injected into a single grid cell.
amax = np.max(np.abs(solver.src_tf)) / (solver.dx * solver.dz) * solver.dt**2

plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, aspect=1)
# psi is (nx, nz): transpose so rows are z.  With origin='upper' the extent puts
# zmin at the top, matching the former pcolormesh + invert_yaxis layout.
qmesh = ax.imshow(
    solver.psi[..., 1].T,
    extent=(solver.xmin, solver.xmax, solver.zmax, solver.zmin),
    vmin=-amax,
    vmax=amax,
    cmap=plt.get_cmap('seismic'),
)
time_text = ax.text(0.05, 0.95, '', ha='left', va='top',
                    transform=ax.transAxes, bbox=dict(facecolor='w'))
ax.set_xlabel('x')
ax.set_ylabel('z')
qmesh_ani = animation.FuncAnimation(fig, update_qmesh, fargs=(solver, ax, qmesh, time_text),
                                   interval=1, blit=True, frames=60)
qmesh_ani.save(ANIM_PATH, writer='pillow', fps=5)

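## Velocity model

A toy model: a constant background of 4.5 plus a smoothed, zero-mean random perturbation in the band of indices 90–110 along the first (x) axis. Note that the solver cells below do not use this `vv`.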


In [ ]:
# Toy velocity model: constant background plus a smoothed random band
# along the first (x) axis.
SEED = 0
rng = np.random.default_rng(SEED)   # seeded so Restart & Run All reproduces the model
nx, nz = 200, 200
band = slice(90, 110)               # indices along axis 0 that are perturbed
band_width = band.stop - band.start
vv = 4.5 * np.ones((nx, nz), dtype=DTYPE_NUMPY_REAL)
# Difference of two U[0, 1) draws: a zero-mean perturbation in (-1, 1).
vv[band] += rng.random((band_width, nz)) - rng.random((band_width, nz))
vv = scipy.ndimage.gaussian_filter(vv, sigma=1)

In [ ]:
# Same configuration driven through the older all-numba solver.  The
# init_grid / init_src_tf / init_receivers API exists only on
# FDSolver2DWaveEQ_DEP (FDSolver2DWaveEQ has a single init() instead).
solver = FDSolver2DWaveEQ_DEP()
solver.xmin, solver.xmax = 0, 20000
solver.zmin, solver.zmax = 0, 10000
solver.tmin, solver.tmax = 0, 3
solver.vmin, solver.vmax = 3500, 5500
solver.fmax              = 20
solver.init_grid()
solver.init_src_tf()

# Two receivers at mid-depth, symmetric about the centre column.
x0, z0 = round(solver.nx/2), round(solver.nz/2)
dx     = round(solver.nx/20)
solver.rx_ix = np.array([x0-2*dx, x0+2*dx], dtype=DTYPE_NUMPY_UINT)
solver.rx_iz = np.array([z0, z0], dtype=DTYPE_NUMPY_UINT)
solver.init_receivers()

# Colour scale: 3/4 of the peak source amplitude injected into one grid cell.
amax = 3 / 4 * solver.src_tf.max() / (solver.dx*solver.dz) * solver.dt**2

In [ ]:
# Advance 100 steps and show the wavefield over the velocity model.
solver.update(100)
# Colour scale: 3/4 of the peak source amplitude injected into one grid cell.
amax = 3 / 4 * solver.src_tf.max() / (solver.dx*solver.dz) * solver.dt**2
plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, aspect=1)
ax.imshow(solver.vv.T, cmap=plt.get_cmap('gray'))
qmesh = ax.imshow(
    solver.psi[..., 1].T,
    vmin=-amax,
    vmax=amax,
    cmap=plt.get_cmap('seismic'),
    alpha=0.25
)
# Mark receivers at their actual grid indices (z was previously hard-coded to 500).
ax.scatter(solver.rx_ix, solver.rx_iz, marker='v', color='k')
ax.set_xlabel('x index')
ax.set_ylabel('z index')