In [ ]:
import numpy as np
import matplotlib.pyplot as plt

from sympy.interactive import printing
from sympy import Eq, Function, Derivative, symbols, dsolve

# Enable SymPy's pretty printing so that symbolic expressions shown as cell
# outputs below use the best printer available in the current front end.
printing.init_printing()

In [ ]:
# Real-valued symbols for the symbolic part of the notebook. The Python name
# `Az` is bound to the symbol whose printed name is `A_z`.
t, z, Az, W = symbols('t, z, A_z, W', real=True)

# Undefined function; it is applied as u(z) when the equation is built.
u = Function('u')

In [ ]:
# Second-order ODE in z for u(z):  u'' - (W / A_z) u' = 0.
# Despite the variable name, no time derivative appears in this equation;
# the name is kept because the following cell refers to it.
#
# The right-hand side is passed explicitly: the one-argument form `Eq(expr)`
# (implicit rhs of 0) is deprecated in SymPy and not accepted by newer releases.
dudt = Eq(Derivative(u(z), z, z) - W/Az*Derivative(u(z), z), 0)
dudt

In [ ]:
# General solution of the ODE above, naming the unknown u(z) explicitly
# instead of letting dsolve detect it.
dsolve(dudt, u(z))

In [ ]:
# Numeric values for the profile plotted below.
# NOTE: W, Az and z rebind the names that held SymPy symbols in the symbolic
# cells above, so those cells must not be re-run after this one without first
# re-running the `symbols(...)` cell.
Uo = 0.5  # m/s
W = 1e-3  # (m/s)
Az = 10e-4  # (m^2/s)
z = np.linspace(0, 15, num=100)  # (m)

In [ ]:
# Exponentially decaying speed profiles Uo * exp(-k z), where the decay rate k
# is W/Az, W/Az/10 and W/Az/365 respectively.
# NOTE(review): the variable names suggest these stand for day 1, day 10 and
# one year -- confirm the intended physical meaning of the 10 and 365 factors.
day1 = Uo * np.exp((W/Az) * -z)
day10 = Uo * np.exp((W/Az/10) * -z)
day1year = Uo * np.exp((W/Az/365) * -z)

# (profile, line style, legend label) for each curve; the labels state the
# decay rate actually used so the figure can be read on its own.
curves = (
    (day1, 'k', r'$W/A_z$'),
    (day10, 'g', r'$W/(10\,A_z)$'),
    (day1year, 'r', r'$W/(365\,A_z)$'),
)

fig, ax = plt.subplots(figsize=(4, 6))
for profile, style, label in curves:
    ax.plot(profile, z, style, label=label)

# Depth increases downward.
ax.invert_yaxis()

ax.set_title('Speed vs. depth')
ax.set_xlabel(r'Speed [m s$^{-1}$]')
ax.set_ylabel('Depth [m]')
# Trailing semicolon suppresses the repr of the returned Legend in the notebook.
ax.legend(title='Decay rate', loc='best');