In [ ]:
%matplotlib ipympl

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

In [ ]:
def _plain_property(name, doc=None):
    """Build a read/write property backed by the instance attribute ``_<name>``.

    Reading the property before it has been assigned raises AttributeError,
    so ``hasattr(obj, name)`` reports whether the value has been set.
    """
    storage = '_' + name

    def getter(self):
        return getattr(self, storage)

    def setter(self, value):
        setattr(self, storage, value)

    return property(getter, setter, doc=doc)


class FDSolver1DScalarWaveEQ(object):
    """Explicit finite-difference solver for the 1D scalar wave equation.

    Integrates  psi_tt = cc(x)^2 psi_xx + src_tf(t) * delta(x - src_coords)
    with second-order centred differences in space and time; the point source
    is injected at a single grid node with amplitude src_tf[it] * dt^2 / dx.

    Typical set-up order (each step needs the previous ones)::

        solver = FDSolver1DScalarWaveEQ()
        solver.xmin, solver.xmax = ...
        solver.cmin, solver.cmax = ...
        solver.tmin, solver.tmax = ...
        solver.fmax = ...
        solver.initialize_computational_domain()
        solver.velocity_model = some_callable_of_x
        solver.define_src_tf()
        solver.src_coords = x_source
        solver.update(nstep=...)
    """

    # Plain stored parameters and state (no validation on assignment).
    c0 = _plain_property('c0')
    cmin = _plain_property('cmin', 'Minimum wave speed; together with fmax it sets dx.')
    cmax = _plain_property('cmax', 'Maximum wave speed; together with dx it sets dt.')
    cc = _plain_property('cc', 'Wave speed sampled at every node of xx.')
    d2px = _plain_property('d2px', 'Work array for the 2nd space derivative of psi.')
    dx = _plain_property('dx', 'Grid spacing.')
    dt = _plain_property('dt', 'Time step.')
    f0 = _plain_property('f0')
    fmax = _plain_property('fmax', 'Maximum frequency; sets dx and the source wavelet width.')
    it = _plain_property('it', 'Number of steps taken; psi[:, 1] is the field at tmin + it * dt.')
    nx = _plain_property('nx', 'Number of grid nodes.')
    nt = _plain_property('nt', 'Number of time samples.')
    psi = _plain_property('psi', 'Wavefield, shape (nx, 3): previous, current and next time levels.')
    src_idx = _plain_property('src_idx', 'Grid index at which the source is injected.')
    src_tf = _plain_property('src_tf', 'Source time function sampled on tt.')
    t0 = _plain_property('t0', 'Centre time of the source wavelet.')
    tmin = _plain_property('tmin', 'Start time.')
    tmax = _plain_property('tmax', 'End time.')
    tt = _plain_property('tt', 'Time axis: nt samples spaced exactly dt apart.')
    xmin = _plain_property('xmin', 'Left edge of the domain.')
    xmax = _plain_property('xmax', 'Right edge of the domain.')
    xx = _plain_property('xx', 'Grid node positions: nx nodes spaced exactly dx apart.')

    def __init__(self):
        self.it = 0
        # Class name used in error messages (also correct for subclasses).
        self._class = type(self).__name__

    @property
    def src_coords(self):
        """Source position; assigning it also sets ``src_idx``.

        Raises:
            AttributeError: if the computational domain is not initialized.
            ValueError: if the position falls outside the grid.
        """
        return self._src_coords

    @src_coords.setter
    def src_coords(self, value):
        for attr in ('xmin', 'dx', 'nx'):
            if not hasattr(self, attr):
                raise AttributeError(f'Please run {self._class}.initialize_computational_domain() before setting {self._class}.src_coords')
        idx = int((value - self.xmin) // self.dx)
        # A negative index would silently wrap around to the far end of psi.
        if not 0 <= idx < self.nx:
            x_end = self.xmin + self.nx * self.dx
            raise ValueError(f'{self._class}.src_coords={value} lies outside the computational domain [{self.xmin}, {x_end})')
        self._src_coords = value
        self.src_idx = idx

    @property
    def velocity_model(self):
        """Callable c(x); assigning it samples it at every node of ``xx`` into ``cc``.

        Raises:
            AttributeError: if the computational domain is not initialized.
        """
        return self._velocity_model

    @velocity_model.setter
    def velocity_model(self, value):
        if not hasattr(self, 'nx'):
            raise AttributeError(f'Please run {self._class}.initialize_computational_domain() before setting {self._class}.velocity_model')
        self._velocity_model = value
        self.cc = np.array([value(x) for x in self.xx])

    def initialize_computational_domain(self):
        """Derive the space/time grids and allocate the wavefield arrays.

        Requires xmin, xmax, tmin, tmax, cmin, cmax and fmax to be set.

        Raises:
            AttributeError: naming the first required parameter that is unset.
        """
        for attr in ('xmin', 'xmax', 'tmin', 'tmax', 'cmin', 'cmax', 'fmax'):
            if not hasattr(self, attr):
                raise AttributeError(f'Please set {self._class}.{attr} before running {self._class}.initialize_computational_domain()')
        # 20 grid points per shortest wavelength cmin / fmax.
        self.dx = self.cmin / self.fmax / 20
        self.nx = int((self.xmax - self.xmin) // self.dx)
        # Courant number cmax * dt / dx == 1 (Courant-Friedrichs-Lewy condition).
        self.dt = self.dx / self.cmax
        self.nt = int((self.tmax - self.tmin) // self.dt)
        # arange keeps the node spacing exactly dx / dt; the previous
        # linspace(start, start + n * step, n) stretched it to n * step / (n - 1).
        self.xx = self.xmin + self.dx * np.arange(self.nx)
        self.tt = self.tmin + self.dt * np.arange(self.nt)
        self.psi = np.zeros((self.nx, 3))
        self.d2px = np.zeros(self.nx)  # 2nd space derivative of psi; end nodes stay 0

    def define_src_tf(self):
        """Define the source time function as a Gaussian-derivative wavelet.

            src_tf(t) = -2 fmax^2 (t - t0) exp(-fmax^2 (t - t0)^2),  t0 = 4 / fmax

        Its amplitude spectrum peaks at sqrt(2) * fmax / (2 pi) ~= 0.23 * fmax,
        i.e. roughly fmax / 4.

        Raises:
            AttributeError: if the computational domain is not initialized.
        """
        for attr in ('fmax', 'tt', 'dt'):
            if not hasattr(self, attr):
                raise AttributeError(f'Please run {self._class}.initialize_computational_domain() before running {self._class}.define_src_tf()')
        self.t0 = 4 / self.fmax
        tau = self.tt - self.t0
        self.src_tf = -2. * tau * self.fmax ** 2 * np.exp(-(self.fmax ** 2) * tau ** 2)

    def plot_source_time_function(self):
        """Plot the source wavelet (left) and its amplitude spectrum (right)."""
        fig, (ax_time, ax_freq) = plt.subplots(1, 2, figsize=(8, 4))
        ax_time.plot(self.tt, self.src_tf)
        ax_time.set_xlim(0, 4 * self.t0)
        ax_time.set_ylabel('Amplitude')
        ax_time.set_xlabel('Time [s]')
        # rfft returns only the non-negative frequencies of the real wavelet.
        spectrum = np.fft.rfft(self.src_tf)
        freqs = np.fft.rfftfreq(self.src_tf.size, d=self.dt)
        ax_freq.plot(freqs, np.abs(spectrum))
        ax_freq.set_xlim(0, self.fmax)
        ax_freq.yaxis.tick_right()
        ax_freq.yaxis.set_label_position('right')
        ax_freq.set_ylabel('Amplitude')
        ax_freq.set_xlabel('Frequency [$s^{-1}$]')

    def update(self, nstep=1):
        """Advance the wavefield by up to ``nstep`` explicit time steps.

        Each step computes
            psi_next = 2 psi_now - psi_prev + cc^2 dt^2 d2psi/dx2,
        adds src_tf[it] * dt^2 / dx at node src_idx, and then shifts the
        time levels in ``psi`` (current -> previous, next -> current).

        Stepping stops early, without error, once the last sample of
        ``src_tf`` has been used, so ``it`` never exceeds ``len(src_tf) - 1``.

        Args:
            nstep: maximum number of time steps to take.
        """
        last_it = len(self.src_tf) - 1
        for _ in range(nstep):
            if self.it >= last_it:
                break  # time axis exhausted; src_tf[self.it + 1] does not exist
            self.it += 1
            psi_prev, psi_now, psi_next = self.psi[:, 0], self.psi[:, 1], self.psi[:, 2]
            self.d2px[1:-1] = (psi_now[2:] - 2 * psi_now[1:-1] + psi_now[:-2]) / self.dx ** 2
            psi_next[:] = 2 * psi_now - psi_prev + self.cc ** 2 * self.dt ** 2 * self.d2px
            psi_next[self.src_idx] += self.src_tf[self.it] / self.dx * self.dt ** 2
            # Sequential copies: column 1 is read before it is overwritten.
            self.psi[:, 0] = psi_now
            self.psi[:, 1] = psi_next

In [ ]:
def velocity_model(x):
    """Three-layer, piecewise-constant wave-speed model.

    Returns 500 for x < -750, 334 for -750 <= x < 750 and 250 otherwise.
    """
    layers = ((-750, 500), (750, 334))  # (exclusive upper bound, speed)
    for upper_bound, speed in layers:
        if x < upper_bound:
            return speed
    return 250

In [ ]:
# Domain bounds, wave-speed range and maximum frequency must all be set
# before the computational grid can be initialized.
solver = FDSolver1DScalarWaveEQ()
solver.xmin = -1000
solver.xmax = 1000
solver.cmin = 250
solver.cmax = 500
solver.tmin = 0
solver.tmax = 10
solver.fmax = 20
solver.initialize_computational_domain()

# These depend on the grid, so they come after initialization.
solver.velocity_model = velocity_model
solver.define_src_tf()
solver.src_coords = 0

In [ ]:
# Inspect the source wavelet in time and frequency before running the simulation.
solver.plot_source_time_function();

In [ ]:
def update_line(idx, solver, ax, line, time_label=None, steps_per_frame=10):
    """FuncAnimation callback: advance ``solver`` and refresh the plot.

    Args:
        idx: frame index supplied by FuncAnimation (unused).
        solver: FDSolver1DScalarWaveEQ instance to advance.
        ax: axes holding ``line``; used to create a time label if none is given.
        line: Line2D that displays ``solver.psi[:, 1]``.
        time_label: Text artist showing the simulation time. Create it once and
            pass it in so it is reused; if None, a new Text is added to ``ax``
            on every call.
        steps_per_frame: solver time steps taken per animation frame.

    Returns:
        The artists modified in this frame, as required for blitting.
    """
    # Never step past the last source sample, otherwise solver.update would
    # index src_tf out of range.
    remaining = len(solver.src_tf) - 1 - solver.it
    nstep = min(steps_per_frame, remaining)
    if nstep > 0:
        solver.update(nstep=nstep)
    line.set_ydata(solver.psi[:, 1])
    if time_label is None:
        time_label = ax.text(0.05, 0.95, '', ha='left', va='top', transform=ax.transAxes, bbox=dict(facecolor='w'))
    time_label.set_text(f'{solver.tmin + solver.it * solver.dt: 06.3f} s')
    return (line, time_label)


STEPS_PER_FRAME = 10
PSI_YLIM = 1e-2  # fixed symmetric y-range for the wavefield display

plt.close('all')
fig, ax = plt.subplots()
ax.set_xlim(solver.xmin, solver.xmax)
ax.set_ylim(-PSI_YLIM, PSI_YLIM)
ax.set_xlabel('x')
ax.set_ylabel(r'$\psi$')
line, = ax.plot(solver.xx, solver.psi[:, 1])
# One Text artist updated in place each frame (a new one per frame would pile up).
time_label = ax.text(0.05, 0.95, '', ha='left', va='top', transform=ax.transAxes, bbox=dict(facecolor='w'))
# Enough frames to reach the end of the time axis; then stop instead of looping.
steps_left = len(solver.src_tf) - 1 - solver.it
n_frames = max(1, (steps_left + STEPS_PER_FRAME - 1) // STEPS_PER_FRAME)
line_ani = animation.FuncAnimation(fig, update_line, frames=n_frames,
                                   fargs=(solver, ax, line, time_label, STEPS_PER_FRAME),
                                   interval=1, blit=True, repeat=False)

In [ ]:
# Peak absolute amplitude of the source time function.
np.abs(solver.src_tf).max()

In [ ]: