Learning Objectives: Understand the numerical solution of ODEs and use scipy.integrate.odeint
to solve and explore ODEs numerically.
In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
Many of the equations of Physics, Chemistry, Statistics, Data Science, etc. are Ordinary Differential Equations, or ODEs. An ODE is a differential equation of the form:
$$ \frac{d\vec{y}}{dt} = \vec{f}\left(\vec{y}(t), t\right) $$The goal is usually to solve for the $N$-dimensional state vector $\vec{y}(t)$ at each time $t$ given some initial condition:
$$ \vec{y}(0) = \vec{y}_0 $$In this case we are using $t$ as the independent variable, which is common when studying differential equations that depend on time. But any independent variable may be used, such as $x$. Solving an ODE numerically usually involves picking a set of $M$ discrete times at which we wish to know the solution:
In [2]:
tmax = 10.0 # The max time
M = 100 # Use 100 times between [0,tmax]
t = np.linspace(0,tmax,M)
t
It is useful to define the step size $h$ as:
$$ h = t_{i+1} - t_i $$
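Because np.linspace includes both endpoints, $M$ points split $[0, t_{max}]$ into $M-1$ equal intervals, so $h = t_{max}/(M-1)$. A quick check of that arithmetic, reusing the same tmax and M values as above (h_expected is just an illustrative name):

```python
import numpy as np

tmax = 10.0  # same max time as above
M = 100      # same number of times as above
t = np.linspace(0, tmax, M)

# linspace includes both endpoints, so there are M-1 intervals of equal width
h_expected = tmax / (M - 1)
h = t[1] - t[0]
print(h, h_expected)

# every consecutive spacing equals h_expected, up to round-off
assert np.allclose(np.diff(t), h_expected)
```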
In [3]:
h = t[1]-t[0]
print("h =", h)
The numerical solution of an ODE will then be an $M\times N$ array $y_{ij}$ such that:
$$ \left[\vec{y}(t_i)\right]_j = y_{ij} $$In other words, the rows of the array $y_{ij}$ are the state vectors $\vec{y}(t_i)$ at times $t_i$. Here is an array of zeros having the right shape for the values of $N$ and $M$ we are using here:
In [4]:
N = 2 # 2d case
y = np.zeros((M, N))
print("N =", N)
print("M =", M)
print("y.shape =", y.shape)
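To make the row/column convention concrete: indexing a row gives one state vector, while indexing a column gives the history of a single component. A small sketch using the same shape (the names row and col are illustrative):

```python
import numpy as np

M, N = 100, 2
y = np.zeros((M, N))

# Row i is the state vector y(t_i): one value per component
row = y[5, :]
# Column j is the time series of component j across all M times
col = y[:, 1]
print(row.shape, col.shape)
```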
A numerical ODE solver takes the $i$th row of this array, y[i,:], and calculates the $(i+1)$th row, y[i+1,:]. This process starts with the initial condition y[0,:] and continues through all of the times with steps of size $h$. One of the core ideas of numerical ODE solvers is that the accumulated error scales as $\mathcal{O}(h^n)$, where $n\geq1$ is the order of the method. You can therefore reduce the error by making $h$ smaller (up to a point, beyond which round-off error and computational cost start to dominate) or by choosing an ODE solver with a larger value of $n$: because $h<1$ in practice, a larger $n$ makes $h^n$ smaller.
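The row-by-row update described above can be sketched with the forward Euler method, the simplest such scheme ($n=1$), chosen here for illustration rather than taken from this notebook. The helper names euler and decay are hypothetical; the test problem $dy/dt=-ky$ has the exact solution $y(t)=e^{-kt}$, so the error can be measured directly:

```python
import numpy as np

def euler(f, y0, t, *args):
    """Forward Euler: y[i+1] = y[i] + h * f(y[i], t[i])."""
    y = np.zeros((len(t), len(y0)))
    y[0, :] = y0                      # row 0 is the initial condition
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        # compute row i+1 from row i
        y[i + 1, :] = y[i, :] + h * np.asarray(f(y[i, :], t[i], *args))
    return y

def decay(yvec, t, k):
    # dy/dt = -k y, with exact solution y(t) = y0 * exp(-k t)
    return -k * yvec

tmax, k = 2.0, 1.0
errors = []
for M in (21, 41, 81):                # each row halves h: 0.1, 0.05, 0.025
    t = np.linspace(0, tmax, M)
    y = euler(decay, [1.0], t, k)
    err = abs(y[-1, 0] - np.exp(-k * tmax))
    errors.append(err)
    print(f"h = {t[1] - t[0]:.4f}  error at tmax = {err:.2e}")
```

Halving $h$ roughly halves the error, which is what $\mathcal{O}(h^1)$ scaling predicts for this first-order method.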
Here are some common numerical algorithms for solving ODEs:
There are many other specialized methods and tricks for solving ODEs (see this page). One of the most common tricks is to use an adaptive step size, which changes the value of $h$ at each step to keep an estimate of the error below a chosen tolerance.
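One simple way to implement an adaptive step size is to take each step with two methods of different order and treat their difference as an error estimate. The sketch below pairs Euler (order 1) with Heun's method (order 2); it is an illustration under our own assumptions, not how odeint works internally (odeint wraps the LSODA solver from ODEPACK). The function name adaptive_heun and its tol and h parameters are hypothetical:

```python
import numpy as np

def adaptive_heun(f, y0, t0, tmax, tol=1e-6, h=0.1):
    """Integrate dy/dt = f(y, t) from t0 to tmax with an adaptive step.

    Each step computes an Euler (order 1) and a Heun (order 2) estimate;
    their difference approximates the local error of the Euler estimate,
    and h is shrunk or grown to keep that estimate near tol.
    """
    t, y = t0, np.asarray(y0, dtype=float)
    ts, ys = [t], [y.copy()]
    while t < tmax:
        h = min(h, tmax - t)              # don't step past tmax
        k1 = np.asarray(f(y, t))
        k2 = np.asarray(f(y + h * k1, t + h))
        y_euler = y + h * k1
        y_heun = y + 0.5 * h * (k1 + k2)
        err = np.max(np.abs(y_heun - y_euler))
        if err <= tol:                    # accept the step, keep the better estimate
            t, y = t + h, y_heun
            ts.append(t)
            ys.append(y.copy())
        # the Euler local error scales like h**2, so rescale h by sqrt(tol/err),
        # with a safety factor and limits on how fast h may change
        factor = 0.9 * np.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    return np.array(ts), np.array(ys)

ts, ys = adaptive_heun(lambda y, t: -y, [1.0], 0.0, 5.0, tol=1e-6)
steps = np.diff(ts)
print(len(ts), steps.min(), steps.max())
print(abs(ys[-1, 0] - np.exp(-5.0)))
```

For this decaying solution the error estimate shrinks as $y$ shrinks, so the accepted steps grow over time instead of staying fixed.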
SciPy provides a general-purpose ODE solver, scipy.integrate.odeint, that can handle a wide variety of linear and non-linear multidimensional ODEs.
In [5]:
from scipy.integrate import odeint
In [6]:
odeint?
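To show how odeint works, we will solve the Lotka–Volterra equations, an example of a predator-prey model:

$$ \frac{dx}{dt} = \alpha x - \beta x y $$
$$ \frac{dy}{dt} = \delta x y - \gamma y $$

where $x(t)$ is the number of prey, $y(t)$ is the number of predators, and $\alpha$, $\beta$, $\delta$, $\gamma$ are positive parameters: $\alpha$ sets the growth rate of the prey, $\beta$ the rate at which predators consume prey, $\delta$ the rate at which predators grow by consuming prey, and $\gamma$ the death rate of the predators.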
It is important to note here that $y(t)$ is different from the overall solution vector $\vec{y}(t)$. In fact, perhaps confusingly, in this case $\vec{y}(t)=[x(t),y(t)]$.
To integrate this system of differential equations, we must define a function derivs that computes the right-hand side of the differential equation, $\vec{f}(\vec{y}(t), t)$. The signature of this function is set by odeint itself:
def derivs(yvec, t, *args):
    ...
    return dyvec
In this signature, yvec will be a 1d NumPy array with $N$ elements that are the values of the solution at the current time, $\vec{y}(t)$; t will be the current time; and *args will be other arguments, typically parameters in the differential equation. The derivs function must return a 1d NumPy array whose elements are the values of the function $\vec{f}(\vec{y}(t), t)$.
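To see this signature in action on a problem with a known answer, here is a minimal sketch for exponential decay, $dy/dt = -ky$, a 1d case ($N=1$) with one extra parameter. The function name decay and the value of k are illustrative choices, not part of the notebook:

```python
import numpy as np
from scipy.integrate import odeint

def decay(yvec, t, k):
    # yvec holds the current state (N = 1 here); return f(y, t) = -k y
    return np.array([-k * yvec[0]])

k = 0.5
t = np.linspace(0, 10.0, 101)
soln = odeint(decay, np.array([1.0]), t, args=(k,))

print(soln.shape)             # one row per time, one column per component
exact = np.exp(-k * t)
print(np.max(np.abs(soln[:, 0] - exact)))
```

Note that args must be a tuple, so a single parameter needs the trailing comma in (k,).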
In [62]:
def derivs(yvec, t, alpha, beta, delta, gamma):
    # unpack the state vector: prey x and predator y
    x = yvec[0]
    y = yvec[1]
    # right-hand side of the Lotka-Volterra equations
    dx = alpha*x - beta*x*y
    dy = delta*x*y - gamma*y
    return np.array([dx, dy])
Here are the parameters and initial condition we will use to solve the differential equation. In this case, our prey variable $x$ is the number of rabbits and the predator variable $y$ is the number of foxes (foxes eat rabbits).
In [57]:
nfoxes = 10
nrabbits = 20
ic = np.array([nrabbits, nfoxes])
maxt = 20.0
alpha = 1.0
beta = 0.1
delta = 0.1
gamma = 1.0
Here we call odeint with our derivs function, the initial condition ic, the array of times t, and the extra parameters:
In [60]:
t = np.linspace(0, maxt, int(100*maxt))
soln = odeint(derivs,                            # function to compute the derivatives
              ic,                                # array of initial conditions
              t,                                 # array of times
              args=(alpha, beta, delta, gamma),  # extra args
              atol=1e-9, rtol=1e-8)              # absolute and relative error tolerances
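One way to sanity-check this solution, an addition of ours rather than part of the original walkthrough, uses the fact that the Lotka–Volterra equations conserve the quantity $V = \delta x - \gamma \ln x + \beta y - \alpha \ln y$ along every exact trajectory (differentiating $V$ and substituting the equations gives $dV/dt = 0$). A numerical solution should keep $V$ nearly constant. The sketch below repeats the setup so it runs on its own; the names V and drift are illustrative:

```python
import numpy as np
from scipy.integrate import odeint

def derivs(yvec, t, alpha, beta, delta, gamma):
    x, y = yvec
    return np.array([alpha*x - beta*x*y, delta*x*y - gamma*y])

alpha, beta, delta, gamma = 1.0, 0.1, 0.1, 1.0
ic = np.array([20.0, 10.0])           # 20 rabbits, 10 foxes, as above
t = np.linspace(0, 20.0, 2000)
soln = odeint(derivs, ic, t, args=(alpha, beta, delta, gamma),
              atol=1e-9, rtol=1e-8)

# V is constant along exact Lotka-Volterra trajectories
x, y = soln[:, 0], soln[:, 1]
V = delta*x - gamma*np.log(x) + beta*y - alpha*np.log(y)
drift = np.max(np.abs(V - V[0]))
print(f"max drift in V: {drift:.1e}")
```

A drift that grows noticeably when atol and rtol are loosened is a sign that the tolerances, not the model, are limiting the accuracy.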
We can plot the components of the solution as a function of time as follows:
In [59]:
plt.plot(t, soln[:,0], label='rabbits')
plt.plot(t, soln[:,1], label='foxes')
plt.xlabel('t')
plt.ylabel('count')
plt.legend();
We can also make a parametric plot of $[x(t),y(t)]$:
In [67]:
plt.plot(soln[:,0], soln[:,1])
plt.xlim(0, 25)
plt.ylim(0, 25)
plt.xlabel('rabbits')
plt.ylabel('foxes');
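The closed curve in the parametric plot circles the nonzero equilibrium of the model, where both derivatives vanish. Setting $\alpha x - \beta x y = 0$ and $\delta x y - \gamma y = 0$ with $x, y > 0$ gives $x^* = \gamma/\delta$ and $y^* = \alpha/\beta$, which is $(10, 10)$ for the parameters used here. A short check, with the name eq chosen for illustration:

```python
import numpy as np

def derivs(yvec, t, alpha, beta, delta, gamma):
    x, y = yvec
    return np.array([alpha*x - beta*x*y, delta*x*y - gamma*y])

alpha, beta, delta, gamma = 1.0, 0.1, 0.1, 1.0
# Nonzero fixed point: alpha - beta*y = 0 and delta*x - gamma = 0
eq = np.array([gamma/delta, alpha/beta])
print(eq)
print(derivs(eq, 0.0, alpha, beta, delta, gamma))
```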