In [ ]:
%matplotlib ipympl

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

In [ ]:
class FDSolver2DScalarWaveEQ(object):
    '''
    Explicit finite-difference solver for the 2D scalar wave equation

        d2psi/dt2 = cc**2 * (d2psi/dx2 + d2psi/dy2) + src_tf(t) * delta(x - x_src, y - y_src)

    discretised with three-point (second-order) operators in x, y and t.

    Typical use:
        1. set xmin/xmax, ymin/ymax, tmin/tmax, cmin/cmax and fmax;
        2. call initialize_computational_domain();
        3. set the velocity, either as an (nx, ny) array via ``cc`` or as a callable
           f(array([x, y])) -> velocity via ``vv``;
        4. call define_src_tf() and set ``src_coords``;
        5. call update(nstep) repeatedly; psi[:, :, 1] holds the current wavefield.

    ``psi`` has shape (nx, ny, 3): slot 0 is the previous time level, slot 1 the current
    one and slot 2 is scratch space for the next one.
    '''

    def __init__(self):
        self._it = 0
        # Class name used in error messages.
        self._class = type(self).__name__

    def _attribute(name, doc=None):
        # Builds a plain read/write property stored in ``self._<name>``. Reading one that has
        # never been assigned raises AttributeError, which the hasattr() checks below rely on.
        # Only used while the class body executes; deleted right after the declarations.
        private = '_' + name

        def fget(self):
            return getattr(self, private)

        def fset(self, value):
            setattr(self, private, value)

        return property(fget, fset, doc=doc)

    # Physical parameters
    c0   = _attribute('c0', 'Velocity c0 (stored only; no method of this class reads it).')
    cmin = _attribute('cmin', 'Minimum velocity; sets the grid spacing dx.')
    cmax = _attribute('cmax', 'Maximum velocity; sets the time step dt.')
    cc   = _attribute('cc', 'Velocity on the grid, array of shape (nx, ny).')
    f0   = _attribute('f0', 'Frequency f0 (stored only; no method of this class reads it).')
    fmax = _attribute('fmax', 'Maximum frequency; sets dx and the source wavelet.')

    # Domain extent
    xmin = _attribute('xmin', 'Lower x bound of the domain.')
    xmax = _attribute('xmax', 'Upper x bound of the domain.')
    ymin = _attribute('ymin', 'Lower y bound of the domain.')
    ymax = _attribute('ymax', 'Upper y bound of the domain.')
    tmin = _attribute('tmin', 'Start time of the simulation.')
    tmax = _attribute('tmax', 'End time of the simulation.')

    # Discretisation (filled in by initialize_computational_domain)
    dx = _attribute('dx', 'Grid spacing in x.')
    dy = _attribute('dy', 'Grid spacing in y.')
    dt = _attribute('dt', 'Time step.')
    nx = _attribute('nx', 'Number of grid nodes in x.')
    ny = _attribute('ny', 'Number of grid nodes in y.')
    nt = _attribute('nt', 'Number of time samples.')
    xx = _attribute('xx', 'x coordinates of the grid nodes, spaced by dx.')
    yy = _attribute('yy', 'y coordinates of the grid nodes, spaced by dy.')
    tt = _attribute('tt', 'Sample times, spaced by dt.')

    # Solver state
    it   = _attribute('it', 'Index of the time level currently held in psi[:, :, 1].')
    psi  = _attribute('psi', 'Wavefield buffer of shape (nx, ny, 3): previous, current, next.')
    d2px = _attribute('d2px', '2nd space derivative of psi wrt x.')
    d2py = _attribute('d2py', '2nd space derivative of psi wrt y.')

    # Source
    src_idx = _attribute('src_idx', 'Grid indices (ix, iy) of the point source.')
    src_tf  = _attribute('src_tf', 'Source time function sampled at tt.')
    t0      = _attribute('t0', 'Centre time of the source wavelet.')

    del _attribute

    @property
    def grid(self):
        '''Array of shape (nx, ny, 2) holding the (x, y) coordinates of every grid node.'''
        return (
            np.moveaxis(np.stack(np.meshgrid(self.xx, self.yy, indexing='ij')), 0, -1)
        )

    @property
    def src_coords(self):
        '''Physical (x, y) coordinates of the point source.'''
        return (self._src_coords)

    @src_coords.setter
    def src_coords(self, value):
        '''
        Set the source position and derive ``src_idx`` from it.

        Raises AttributeError if the domain has not been initialised, and ValueError if the
        point falls outside the grid.
        '''
        for attr in ('xmin', 'ymin', 'dx', 'dy', 'nx', 'ny'):
            if not hasattr(self, attr):
                raise (AttributeError(f'Please run {self._class}.initialize_computational_domain() before setting {self._class}.src_coords'))
        ix = int((value[0] - self.xmin) // self.dx)
        iy = int((value[1] - self.ymin) // self.dy)
        # A negative index would silently wrap to the opposite edge, and a too-large one
        # would only fail later inside update(), so reject both here.
        if not (0 <= ix < self.nx and 0 <= iy < self.ny):
            raise ValueError(
                f'{self._class}.src_coords={tuple(value)} maps to grid index ({ix}, {iy}), '
                f'outside the {self.nx} x {self.ny} computational grid'
            )
        self._src_coords = value
        self.src_idx = (ix, iy)

    @property
    def vv(self):
        '''Velocity model: callable mapping array([x, y]) to the velocity at that point.'''
        return (self._vv)

    @vv.setter
    def vv(self, value):
        '''Store the velocity model and evaluate it on every grid node into ``cc``.'''
        if not hasattr(self, 'nx') or not hasattr(self, 'ny'):
            raise (AttributeError(f'Please run {self._class}.initialize_computational_domain() before setting {self._class}.vv'))
        self._vv = value
        self.cc = np.apply_along_axis(value, -1, self.grid)

    def initialize_computational_domain(self):
        '''
        Derive grid spacing, time step, coordinate axes and zeroed state arrays from the
        domain bounds, velocity range and maximum frequency. Also resets ``it`` to 0, so
        calling it again restarts the simulation.

        Raises AttributeError if any of the required input attributes is unset.
        '''
        for attr in ('xmin', 'xmax', 'ymin', 'ymax', 'tmin', 'tmax', 'cmin', 'cmax', 'fmax'):
            if not hasattr(self, attr):
                raise (AttributeError(f'Please set {self._class}.{attr} before running {self._class}.initialize_computational_domain()'))
        # 20 grid points per shortest wavelength (cmin / fmax).
        self.dx = self.cmin / self.fmax / 20
        self.dy = self.dx
        self.nx = int((self.xmax - self.xmin) // self.dx)
        self.ny = int((self.ymax - self.ymin) // self.dy)
        # Courant-Friedrichs-Lewy condition: cmax * dt / dx = 0.5, below the 1 / sqrt(2)
        # stability limit of this 2D scheme.
        self.dt = 0.5 * self.dx / self.cmax
        self.nt = int((self.tmax - self.tmin) // self.dt)
        # Sample the axes with exactly the step sizes the stencil uses. np.linspace would
        # stretch the spacing to land on xmax/ymax/tmax, so tt[it] would drift from the
        # time actually simulated after `it` steps.
        self.xx = self.xmin + self.dx * np.arange(self.nx)
        self.yy = self.ymin + self.dy * np.arange(self.ny)
        self.tt = self.tmin + self.dt * np.arange(self.nt)
        self.psi = np.zeros((self.nx, self.ny, 3))
        self.d2px = np.zeros((self.nx, self.ny))
        self.d2py = np.zeros((self.nx, self.ny))
        self.it = 0

    def define_src_tf(self):
        '''
        Defines a Gaussian-derivative wavelet with center frequency equal to fmax / 4,
        centred at t0 = 4 / fmax and sampled at ``tt``.
        '''
        for attr in ('fmax', 'tt', 'dt'):
            if not hasattr(self, attr):
                raise (AttributeError(f'Please run {self._class}.initialize_computational_domain() before running {self._class}.define_src_tf()'))
        self.t0 = 4 / self.fmax
        self.src_tf = -2. * (self.tt - self.t0) * (self.fmax ** 2) * (np.exp(-1.0 * (self.fmax ** 2) * (self.tt - self.t0) ** 2))

    def plot_source_time_function(self):
        '''Plot the source wavelet (left) and its amplitude spectrum (right).'''
        fig = plt.figure(figsize=(8, 4))
        ax = fig.add_subplot(1, 2, 1)
        ax.plot(self.tt, self.src_tf)
        ax.set_xlim(0, 4 * self.t0)
        ax.set_ylabel('Amplitude')
        ax.set_xlabel('Time [s]')
        ax = fig.add_subplot(1, 2, 2)
        # Real input: rfft returns the non-negative frequencies directly.
        SS = np.fft.rfft(self.src_tf)
        ff = np.fft.rfftfreq(self.src_tf.size, d=self.dt)
        ax.plot(ff, np.abs(SS))
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position('right')
        ax.set_ylabel('Amplitude')
        ax.set_xlabel('Frequency [$s^{-1}$]')

    def _advance(self, nstep, spatial_term, cell_size):
        # Shared leapfrog time loop of update() and update_dep().
        # spatial_term(p) returns the spatial second-derivative term for the current field p;
        # cell_size turns the point-source amplitude into a per-cell value.
        for _ in range(nstep):
            n = self.it
            # The new level n + 1 must still be a sample of tt / src_tf.
            if n + 1 >= self.nt:
                break
            current, previous = self.psi[:, :, 1], self.psi[:, :, 0]
            self.psi[:, :, 2] = 2 * current - previous + self.cc ** 2 * self.dt ** 2 * spatial_term(current)
            # The source sampled at t_n drives the step t_n -> t_{n+1}.
            ix, iy = self.src_idx
            self.psi[ix, iy, 2] += self.src_tf[n] / cell_size * self.dt ** 2
            # Shift time levels. Order matters: slot 0 takes the old slot 1 before slot 1 is
            # overwritten with the new level from slot 2.
            self.psi[:, :, 0] = self.psi[:, :, 1]
            self.psi[:, :, 1] = self.psi[:, :, 2]
            self.it = n + 1

    def update_dep(self, nstep=1):
        '''
        Variant of update() that uses only the x second derivative (d2py is ignored) and
        scales the point source by 1 / dx instead of 1 / (dx * dy). The ``_dep`` suffix
        suggests it is deprecated -- prefer update().

        Stops early, without error, once psi[:, :, 1] has reached the last sample of ``tt``.
        '''
        def second_derivative_x(p):
            self.d2px[1:-1] = (p[2:] - 2 * p[1:-1] + p[:-2]) / self.dx ** 2
            return self.d2px

        self._advance(nstep, second_derivative_x, self.dx)

    def update(self, nstep=1):
        '''
        Advance the wavefield by up to ``nstep`` time steps using three-point operators in
        x, y and t. Edge rows/columns of d2px/d2py are left at zero.

        Stops early, without error, once psi[:, :, 1] has reached the last sample of ``tt``.
        '''
        def laplacian(p):
            self.d2px[1:-1] = (p[2:] - 2 * p[1:-1] + p[:-2]) / self.dx ** 2
            self.d2py[:, 1:-1] = (p[:, 2:] - 2 * p[:, 1:-1] + p[:, :-2]) / self.dy ** 2
            return self.d2px + self.d2py

        self._advance(nstep, laplacian, self.dx * self.dy)

In [ ]:
solver = FDSolver2DScalarWaveEQ()
solver.xmin, solver.xmax = 0, 100
solver.ymin, solver.ymax = 0, 500
solver.cmin, solver.cmax = 500, 1000
solver.tmin, solver.tmax = 0, 10
solver.fmax              = 25
solver.initialize_computational_domain()
# Homogeneous medium at the minimum velocity.
solver.cc                = solver.cmin * np.ones((solver.nx, solver.ny))
solver.define_src_tf()
# Point source at the centre of the domain; x must lie in [xmin, xmax) and y in
# [ymin, ymax), otherwise the source falls off the grid.
solver.src_coords        = 50, 250

In [ ]:
def update_qmesh(idx, solver, ax, qmesh, time_text=None):
    '''
    FuncAnimation callback: advance the solver by 10 steps and refresh the plot.

    idx       -- frame index supplied by FuncAnimation (unused)
    solver    -- the FDSolver2DScalarWaveEQ being animated
    ax        -- Axes holding the wavefield plot
    qmesh     -- QuadMesh created from the full (nx, ny) field psi[..., 1]
    time_text -- Text artist reused for the elapsed-time label; if None a new one is
                 created on ax (this adds an artist on every call)

    Returns the artists that changed, as required by blit=True.
    '''
    solver.update(nstep=10)
    # The mesh holds one value per grid node (nx * ny), not (nx - 1) * (ny - 1) cells.
    qmesh.set_array(solver.psi[..., 1].ravel())
    label = f'{solver.tmin + solver.it * solver.dt: 06.3f} s'
    if time_text is None:
        time_text = ax.text(0.05, 0.95, label, ha='left', va='top', transform=ax.transAxes, bbox=dict(facecolor='w'))
    else:
        time_text.set_text(label)
    return (qmesh, time_text)


plt.close('all')
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, aspect=1)
xy = solver.grid
# Symmetric colour scale around zero, sized by the peak source amplitude.
amp = np.max(np.abs(solver.src_tf)) * solver.dt ** 2
qmesh = ax.pcolormesh(
    xy[..., 0],
    xy[..., 1],
    solver.psi[..., 1],
    vmin=-amp,
    vmax=amp,
    cmap='seismic',
    shading='nearest',
)
ax.invert_yaxis()
fig.colorbar(qmesh, ax=ax, orientation='horizontal')
time_text = ax.text(0.05, 0.95, f'{solver.tmin + solver.it * solver.dt: 06.3f} s', ha='left', va='top', transform=ax.transAxes, bbox=dict(facecolor='w'))
# Each frame advances 10 solver steps; stop before the source time function runs out of
# samples. init_func keeps FuncAnimation's initial draw from stepping the solver.
n_frames = (solver.nt - 1) // 10
qmesh_ani = animation.FuncAnimation(fig, update_qmesh, frames=n_frames,
                                    init_func=lambda: (qmesh, time_text),
                                    fargs=(solver, ax, qmesh, time_text),
                                    interval=1, blit=True, repeat=False)