ODE solver based on the Runge-Kutta method

Suppose that we have a differential equation in state-space form, where $\mathbf x(t)$ is the state vector and $\mathbf u(t)$ is the input: $$ \frac{d \mathbf x}{dt} = \mathbf A \mathbf x(t) + \mathbf b \mathbf u(t) $$

Using the classical (fourth-order) Runge-Kutta algorithm, the update from $\mathbf x(t)$ to $\mathbf x(t+\Delta t)$ is approximated as: $$ \mathbf x(t+\Delta t) \simeq \mathbf x(t) + \frac{\Delta t}{6} \left \{ \mathbf d_1 + 2 \mathbf d_2 + 2 \mathbf d_3 + \mathbf d_4 \right \} $$

where $\Delta t$ is the simulation time step and the slopes (gradients) $\mathbf d_i, i \in \{1,2,3,4 \}$ are defined as follows, with the input $\mathbf u$ held constant over the step:

\begin{align} \mathbf d_1 & = \mathbf A \mathbf x + \mathbf b \mathbf u \\ \mathbf x_1 & = \mathbf x + \frac{\Delta t}{2}\mathbf d_1 \\ \mathbf d_2 & = \mathbf A \mathbf x_1 + \mathbf b \mathbf u \\ \mathbf x_2 & = \mathbf x + \frac{\Delta t}{2}\mathbf d_2 \\ \mathbf d_3 & = \mathbf A \mathbf x_2 + \mathbf b \mathbf u \\ \mathbf x_3 & = \mathbf x + \Delta t\mathbf d_3 \\ \mathbf d_4 & = \mathbf A \mathbf x_3 + \mathbf b \mathbf u \end{align}
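Because the system is linear and $\mathbf u$ is constant over the step, the four stages collapse algebraically: $\mathbf d_2 = (\mathbf I + \tfrac{\Delta t}{2}\mathbf A)\mathbf d_1$, $\mathbf d_3 = (\mathbf I + \tfrac{\Delta t}{2}\mathbf A + \tfrac{\Delta t^2}{4}\mathbf A^2)\mathbf d_1$, and so on, so one step equals $\mathbf x + \Delta t(\mathbf I + \tfrac{\Delta t}{2}\mathbf A + \tfrac{\Delta t^2}{6}\mathbf A^2 + \tfrac{\Delta t^3}{24}\mathbf A^3)\mathbf d_1$. In other words, the state-transition matrix is the Taylor expansion of $e^{\mathbf A \Delta t}$ truncated after the fourth-order term. The sketch below checks this numerically; the helper names `rk4_step_lti` and `rk4_matrices` are illustrative, and the matrices assume the spring-mass-damper coefficients used in the script ($k/m = 5$, $c/m = 1$, $1/m = 0.1$).

```python
import numpy as np

def rk4_step_lti(A, b, x, u, dt):
    """One RK4 step for dx/dt = A x + b u, with u held constant over the step.

    Mirrors the stage equations d1..d4 and x1..x3 above (hypothetical helper).
    """
    d1 = A @ x + b * u
    x1 = x + dt / 2 * d1
    d2 = A @ x1 + b * u
    x2 = x + dt / 2 * d2
    d3 = A @ x2 + b * u
    x3 = x + dt * d3
    d4 = A @ x3 + b * u
    return x + dt / 6 * (d1 + 2 * d2 + 2 * d3 + d4)


def rk4_matrices(A, b, dt):
    """Closed form of the same step: x_next = Phi @ x + Gamma * u (hypothetical helper)."""
    I = np.eye(A.shape[0])
    hA = dt * A
    # Phi: Taylor expansion of expm(A * dt) truncated after the 4th-order term.
    Phi = I + hA + hA @ hA / 2 + hA @ hA @ hA / 6 + hA @ hA @ hA @ hA / 24
    # Gamma: dt * (I + hA/2 + hA^2/6 + hA^3/24) applied to b.
    Gamma = dt * (I + hA / 2 + hA @ hA / 6 + hA @ hA @ hA / 24) @ b
    return Phi, Gamma


# Assumed spring-mass-damper matrices (m = 10, k = 50, c = 10).
A = np.array([[0.0, 1.0], [-5.0, -1.0]])
b = np.array([0.0, 0.1])
x = np.array([0.3, -0.2])
u, dt = 10.0, 0.1  # a deliberately large step so any mismatch would be visible

Phi, Gamma = rk4_matrices(A, b, dt)
print(rk4_step_lti(A, b, x, u, dt))
print(Phi @ x + Gamma * u)  # same vector, up to rounding
```

Seen this way, the local error of each step is $O(\Delta t^5)$: it comes entirely from the Taylor terms of $e^{\mathbf A\Delta t}$ beyond the fourth order that the method drops.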

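The following script uses the Runge-Kutta algorithm to simulate a one-dimensional spring-mass-damper model. Its state $\mathbf x = [p, v]^\top$ holds the position and velocity of the mass, and its input $u$ is an external force.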


```python
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
```

```python
# Coefficients of the spring-mass-damper model.
m = 10  # mass (kg)
k = 50  # spring constant (N/m)
c = 10  # damping coefficient (N*s/m)

# Simulation duration and time step (s).
T = 10
dt = 1e-3

# Dimension of the state space: x = [position, velocity].
d = 2

# Initial state: position (m) and velocity (m/s).
x0 = 0
v0 = 0

# Coefficient matrix and input vector of the LTI (linear time-invariant) system.
A = np.asarray([[0, 1], [-k/m, -c/m]])
b = np.asarray([0, 1/m])
```
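With the state $\mathbf x = [p, v]^\top$, these matrices encode $\dot p = v$ and $\dot v = (-kp - cv + u)/m$, that is, $m\ddot p + c\dot p + kp = u$. As a hypothetical sanity check that is not part of the original notebook, the sketch below computes what that equation implies for the chosen coefficients: the poles of the system, the undamped natural frequency $\omega_n = \sqrt{k/m}$, the damping ratio $\zeta = c / (2\sqrt{km})$, and the equilibrium under a constant force. The names `poles`, `omega_n`, `zeta`, and `x_eq` are illustrative.

```python
import numpy as np

# Same coefficients as the cell above, as floats.
m, k, c = 10.0, 50.0, 10.0
A = np.array([[0.0, 1.0], [-k / m, -c / m]])
b = np.array([0.0, 1.0 / m])

# Poles of the system: eigenvalues of A, i.e. roots of m s^2 + c s + k = 0.
poles = np.linalg.eigvals(A)

# Undamped natural frequency (rad/s) and damping ratio of m p'' + c p' + k p = u.
omega_n = np.sqrt(k / m)
zeta = c / (2.0 * np.sqrt(k * m))

# Equilibrium for a constant force u: 0 = A x + b u  =>  x = -A^{-1} b u.
u_const = 10.0
x_eq = -np.linalg.solve(A, b * u_const)

print("poles:", poles)
print("omega_n =", omega_n, "rad/s, zeta =", zeta)
print("equilibrium [p, v] for u = 10 N:", x_eq)
```

For these values $\zeta \approx 0.22$, so the model is underdamped, and a sustained 10 N force would settle at $p = u/k = 0.2$ m. The pole magnitude $|s| = \sqrt 5 \approx 2.24$ rad/s gives $|s|\Delta t \approx 2.2 \times 10^{-3}$ for $\Delta t = 10^{-3}$ s, which lies well inside the stability region of classical RK4, so the chosen step is conservative.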

```python
N = int(round(T/dt))

# Time grid and input: a 10 N force applied for 1 s <= t <= 2 s.
t_list = np.arange(N) * dt
u_list = np.zeros_like(t_list)
u_list[np.logical_and(1 <= t_list, t_list <= 2)] = 10

# State history; column n holds [position, velocity] at t_list[n].
x_list = np.zeros((d, N))
x = np.asarray([x0, v0], dtype=float)
x_list[:, 0] = x

for n in range(1, N):
    # Advance from t_list[n-1] to t_list[n], holding the input at its value
    # at the start of the step, as in the stage equations.
    u = u_list[n - 1]

    # Runge-Kutta stages d1..d4.
    d1 = A @ x + b * u
    x1 = x + d1 * dt / 2
    d2 = A @ x1 + b * u
    x2 = x + d2 * dt / 2
    d3 = A @ x2 + b * u
    x3 = x + d3 * dt
    d4 = A @ x3 + b * u

    # Weighted average of the four slopes.
    x = x + (d1 + 2*d2 + 2*d3 + d4) * dt / 6
    x_list[:, n] = x

# Plot position, velocity, and input force against time.
fig, axs = plt.subplots(3, 1, figsize=(12, 5), sharex=True)
axs[0].plot(t_list, x_list[0])
axs[0].set_ylabel('Position (m)')
axs[1].plot(t_list, x_list[1])
axs[1].set_ylabel('Velocity (m/s)')
axs[2].plot(t_list, u_list)
axs[2].set_ylabel('Input (N)')
axs[2].set_xlabel('Time (s)')
for ax in axs:
    ax.grid()
```
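One way to check such a simulation, sketched here as an addition that is not part of the original notebook: when the input is piecewise constant and held over each step (zero-order hold), the exact discrete-time solution is $\mathbf x_{n+1} = e^{\mathbf A\Delta t}\mathbf x_n + \mathbf A^{-1}(e^{\mathbf A\Delta t} - \mathbf I)\mathbf b\, u_n$. The sketch computes $e^{\mathbf A\Delta t}$ by eigendecomposition, which assumes $\mathbf A$ is diagonalizable (true here because its two eigenvalues are distinct) and invertible (true because $k > 0$). It then runs RK4 and the exact recursion side by side, using the input sampled at the start of each step. The helper `rk4_step` and the arrays `x_rk4` and `x_exact` are illustrative names.

```python
import numpy as np

# Rebuild the same model and input as the script above.
m, k, c = 10.0, 50.0, 10.0
A = np.array([[0.0, 1.0], [-k / m, -c / m]])
b = np.array([0.0, 1.0 / m])
T, dt = 10.0, 1e-3
N = int(round(T / dt))
t_list = np.arange(N) * dt
u_list = np.where((t_list >= 1) & (t_list <= 2), 10.0, 0.0)

# Exact zero-order-hold discretization:
#   Phi = expm(A dt) via eigendecomposition, Gamma = A^{-1} (Phi - I) b.
w, V = np.linalg.eig(A)
Phi = (V @ np.diag(np.exp(w * dt)) @ np.linalg.inv(V)).real
Gamma = np.linalg.solve(A, (Phi - np.eye(2)) @ b)


def rk4_step(x, u):
    """Classical RK4 step with u held constant over [t, t + dt] (hypothetical helper)."""
    d1 = A @ x + b * u
    d2 = A @ (x + dt / 2 * d1) + b * u
    d3 = A @ (x + dt / 2 * d2) + b * u
    d4 = A @ (x + dt * d3) + b * u
    return x + dt / 6 * (d1 + 2 * d2 + 2 * d3 + d4)


# Both trajectories start from rest, like the notebook (x0 = v0 = 0).
x_rk4 = np.zeros((2, N))
x_exact = np.zeros((2, N))
for n in range(1, N):
    u = u_list[n - 1]  # input at the start of the step
    x_rk4[:, n] = rk4_step(x_rk4[:, n - 1], u)
    x_exact[:, n] = Phi @ x_exact[:, n - 1] + Gamma * u

max_err = np.max(np.abs(x_rk4 - x_exact))
print("max |RK4 - exact| =", max_err)
```

At $\Delta t = 10^{-3}$ s the two trajectories should agree to within rounding error, because the per-step RK4 error scales as $\Delta t^5$. Rerunning with a larger `dt` is a quick way to see the gap grow.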