In [ ]:
%matplotlib ipympl

import matplotlib.pyplot as plt
import numpy as np

The scalar (acoustic) wave equation is

\begin{equation} \frac{\partial^2\psi}{\partial t^2} = c^2\nabla^2\psi + s \end{equation}

where

  • $\psi$ represents the pressure field
  • $c$ represents the propagation velocity
  • $\nabla^2$ is the Laplacian operator

In the 1D case, this simplifies to

\begin{equation} \frac{\partial^2\psi}{\partial t^2} = c^2\frac{\partial^2\psi}{\partial x^2} + s \end{equation}

The first-order finite-difference operator for the second derivative with respect to coordinate $u$ is

\begin{equation} \frac{\partial^2\psi}{\partial u^2} \approx \frac{\psi(u+\Delta u) - 2\psi(u) + \psi(u-\Delta u)}{\Delta u^2} \end{equation}

and substituting this in for $\frac{\partial^2\psi}{\partial t^2}$ and $\frac{\partial^2\psi}{\partial x^2}$ gives

\begin{equation} \frac{\psi(x, t+\Delta t) - 2\psi(x, t) + \psi(x, t-\Delta t)}{\Delta t^2} = c^2(x)\frac{\psi(x+\Delta x, t) - 2\psi(x, t) + \psi(x-\Delta x, t)}{\Delta x^2} + s(x, t) \end{equation}

From this we obtain our extrapolation operator for the time dimension: \begin{equation} \psi(x, t+\Delta t) = \Delta t^2\left[c^2(x)\frac{\psi(x+\Delta x, t) - 2\psi(x, t) + \psi(x-\Delta x, t)}{\Delta x^2} + s(x, t)\right] + 2\psi(x, t) - \psi(x, t-\Delta t) \end{equation}

With this, we can obtain all future values of $\psi$ from the current values. We will use upper and lower indices for time and space, respectively, to discretize this equation on a grid with node spacing $\Delta t$ and $\Delta x$ in time and space.

\begin{equation} \psi(x + i\Delta x, t + n\Delta t) = \psi^n_i \end{equation}

with $i, n \in \mathbb{N}$. i.e.

\begin{equation} \psi^{n+1}_i = c^2_i\frac{\Delta t^2}{\Delta x^2}\left(\psi^n_{i+1} - 2\psi^n_i + \psi^n_{i-1}\right) + 2\psi^n_i - \psi^{n-1}_i + \Delta t^2s^n_i \end{equation}

For the source time function, we will use a Gaussian-derivative wavelet: \begin{equation} s(x, t) = -f_0(t-t_0)exp\left(-\left[4f_0\left(t-t_0\right)\right]^2\right)\delta\left(x-x_0\right) \end{equation} where

  • $t_0$ represents the zero crossing
  • $f_0$ represents the dominant frequency
  • $x_0$ represents the source location
  • $\delta$ is the Dirac-delta function

In [ ]:
# Parameterize the propagation domain
c          = 100         # Wave speed [m/s]
xmin, xmax = -1000, 1000 # Computational domain [m]
tmin, tmax = 0, 20       # Computational domain [s]
f0         = 20          # Dominant frequency   [1/s]
t0         = 4 / f0      # Zero-crossing time   [s]
s0         = 0           # Source location      [m]

dx = c / f0 / 20 # Node spacing along x-axis    [m]
dt   = dx / c    # Node spacing along t-axis    [s]

xx = np.linspace(xmin, xmax, (xmax - xmin) / dx)
tt = np.linspace(tmin, tmax, (tmax - tmin) / dt)
nx   = xx.size       # Number of grid nodes along x-axis
nt   = tt.size       # Number of grid nodes along t-axis

psi     = np.zeros((nx, 3)) # Initialize the wavefield

# Initialize source-time function
ix_src = int((s0-xmin)//dx)     # X-index of source location
ss = - f0 * (tt - t0) * np.exp(-((4 * f0 * (tt - t0)) ** 2))
psi_src = ss / np.max(np.abs(ss))

Plot the source-time function


In [ ]:
plt.close('all')
fig = plt.figure(figsize=(8,4))
fig.suptitle('Source-time function')
ax = fig.add_subplot(1, 2, 1)
ax.plot(tt, psi_src)
ax.set_xlim(0, 3*t0)
ax.set_xlabel('Time [s]')
ax.set_ylabel('Normalized amplitude')
ax = fig.add_subplot(1, 2, 2)
SS = np.fft.fft(psi_src)
freq = np.fft.fftfreq(psi_src.size, d=dt)
SS = SS[freq >= 0]
SS /= np.max(np.abs(SS))
freq = freq[freq >= 0]
ax.plot(freq, np.abs(SS))
ax.set_xlim(np.min(freq), np.max(freq))
ax.set_xlabel('Frequency [Hz]')
ax.set_ylabel('Normalized amplitude')
ax.yaxis.tick_right()
ax.yaxis.set_label_position('right')
\begin{equation} \psi^{n+1}_i = c^2_i\frac{\Delta t^2}{\Delta x^2}\left(\psi^n_{i+1} - 2\psi^n_i + \psi^n_{i-1}\right) + 2\psi^n_i - \psi^{n-1}_i + \Delta t^2s^n_i \end{equation}

In [ ]:
for it in range(nt):
    psi[1:-1, 2] = c ** 2 * dt ** 2 / dx ** 2 * (psi[2:, 1] - 2 * psi[1:-1, 1] + psi[:-2, 1]) + 2 * psi[1:-1, 1] - psi[1:-1, 0] + psi_src[it] * dt ** 2
    if it % 50 == 0:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(psi[:, 1])

In [ ]:


In [ ]:
plt.close('all')

In [33]:
%matplotlib ipympl

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation

wave_simulation_parameters = dict(
    c          = 100,         # Wave speed [m/s]
    xmin       = -1000,
    xmax       = 1000, # Computational domain [m]
    tmin       = 0,
    tmax       = 20,       # Computational domain [s]
    f0         = 20,          # Dominant frequency   [1/s]
    s0         = 0           # Source location      [m]
)

class Wave(object):
    def __init__(self, **kwargs):
        '''
        c          = 100         # Wave speed [m/s]
        xmin, xmax = -1000, 1000 # Computational domain [m]
        tmin, tmax = 0, 20       # Computational domain [s]
        f0         = 20          # Dominant frequency   [1/s]
        t0         = 4 / f0      # Zero-crossing time   [s]
        s0         = 0           # Source location      [m]
        '''
        self._c    = kwargs['c']
        self._xmin = kwargs['xmin']
        self._xmax = kwargs['xmax']
        self._tmin = kwargs['tmin']
        self._tmax = kwargs['tmax']
        self._f0   = kwargs['f0']
        self._t0   = kwargs['f0'] / 4
        self._s0   = kwargs['s0']
        self._dx = self._c / self._f0 / 20 # Node spacing along x-axis    [m]
        self._dt   = self._dx / self._c    # Node spacing along t-axis    [s]

        self._xx = np.linspace(self._xmin, self._xmax, (self._xmax - self._xmin) / self._dx)
        self._tt = np.linspace(self._tmin, self._tmax, (self._tmax - self._tmin) / self._dt)
        self._nx   = self._xx.size       # Number of grid nodes along x-axis
        self._nt   = self._tt.size       # Number of grid nodes along t-axis
        self._it   = 0

        self._psi     = np.zeros((self._nx, 3)) # Initialize the wavefield

        # Initialize source-time function
        self._ix_src = int((self._s0 - self._xmin) // self._dx)     # X-index of source location
        ss = - self._f0 * (self._tt - self._t0) * np.exp(-((4 * self._f0 * (self._tt - self._t0)) ** 2))
        self._src = ss / np.max(np.abs(ss))

    def _update(self):
        self._it += 1
        psi     = self._psi
        psi_src = self._src
        c       = self._c
        dx      = self._dx
        dt      = self._dt
        it      = self._it
        psi[1:-1, 2] = c ** 2 * dt ** 2 / dx ** 2 * (psi[2:, 1] - 2 * psi[1:-1, 1] + psi[:-2, 1]) + 2 * psi[1:-1, 1] - psi[1:-1, 0] + psi_src[it] * dt ** 2
        psi[:, 0] = psi[:, 1]
        psi[:, 1] = psi[:, 2]
        self._psi = psi
    

    def update(self):
        d2px = np.zeros(self._nx)
        for i in range(1, self._nx - 1):
            d2px[i] = (self._psi[i + 1, 1] - 2 * self._psi[i, 1] + self._psi[i - 1, 1]) / self._dx ** 2
        self._psi[1:-1, 2] = 2 * self._psi[1:-1, 1] - self._psi[1:-1, 0] + self._c ** 2 * self._dt ** 2 * d2px[1:-1]
        self._psi[self._ix_src, 2] += self._src[self._it] / (self._dx) * self._dt ** 2
        self._psi[:, 0], self._psi[:, 1] = self._psi[:, 1], self._psi[:, 2]

    def plot_source_time_function(self):
        fig = plt.figure(figsize=(8,4))
        fig.suptitle('Source-time function')
        ax = fig.add_subplot(1, 2, 1)
        ax.plot(self._tt, self._src)
        ax.set_xlim(0, 3*self._t0)
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Normalized amplitude')
        ax = fig.add_subplot(1, 2, 2)
        SS = np.fft.fft(self._src)
        freq = np.fft.fftfreq(self._src.size, d=self._dt)
        SS = SS[freq >= 0]
        SS /= np.max(np.abs(SS))
        freq = freq[freq >= 0]
        ax.plot(freq, np.abs(SS))
        ax.set_xlim(np.min(freq), np.max(freq))
        ax.set_xlabel('Frequency [Hz]')
        ax.set_ylabel('Normalized amplitude')
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position('right')

    def plot(self):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(self._xx, self._psi[:, 1])
        ax.set_xlim(self._xmin, self._xmax)
        ymax = np.max(np.abs(self._src))
        ax.set_ylim(-ymax, ymax)
        
wf = Wave(**wave_simulation_parameters)


/home/malcolmw/local/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:38: DeprecationWarning: object of type <class 'float'> cannot be safely interpreted as an integer.
/home/malcolmw/local/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:39: DeprecationWarning: object of type <class 'float'> cannot be safely interpreted as an integer.

In [ ]:
def update_psi(it, wf, ax, line):
    wf.update()
    line.set_ydata(wf._psi[:, 1])
    ax.text(0.05, 0.95, str(it), ha='left', va='top', transform=ax.transAxes, bbox=dict(facecolor='w'))
    return (ax, line)

wf = Wave(**wave_simulation_parameters)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

line, = ax.plot(wf._xx, wf._psi[:, 0], 'r-')
ax.set_xlim(wf._xmin, wf._xmax)
ax.set_ylim(-1, 1)

line_ani = animation.FuncAnimation(fig, update_psi, wf._nt, fargs=(wf, ax, line),
                                   interval=50, blit=True)

In [34]:
wf = Wave(**wave_simulation_parameters)

for i in range(int(wf._t0 // wf._dt)):
    wf.update()


/home/malcolmw/local/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:38: DeprecationWarning: object of type <class 'float'> cannot be safely interpreted as an integer.
/home/malcolmw/local/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:39: DeprecationWarning: object of type <class 'float'> cannot be safely interpreted as an integer.

In [35]:
# wf._update()
plt.close('all')
wf.plot()



In [ ]:
for i in range(10):
    wf.update()
line.set_ydata(wf.psi[:, 1])

In [ ]:
def update_line(idx, data, line):
    line.set_ydata(data[idx])
    return line,

k, w = 0.5, 1
x = np.linspace(0, 100, 1000)
t = np.linspace(0, 10, 100)
xx, tt = np.meshgrid(x, t)
 
data = np.sin(k*xx - w*tt)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

line, = ax.plot(x, data[0], 'r-')
ax.set_xlim(0, 100)
ax.set_ylim(-1, 1)

line_ani = animation.FuncAnimation(fig, update_line, 100, fargs=(data, line),
                                   interval=50, blit=True)

In [ ]:
l.set_ydata(data[10])

In [ ]: