In [2]:
import sklearn
import numpy as np
from numpy import linalg as LA
from scipy.integrate import odeint
import matplotlib.pyplot as plt
%matplotlib inline
### *Controlled linearized noisy cart-pole balancing problem*
### from "Reinforcement Learning for Humanoid Robotics"
### Jan Peters, Sethu Vijayakumar, Stefan Schaal, ICHR 2003
### refs: [peters03, mehta02]
### NOTES : LQR problem with noisy state evolution and noisy control

Nx = 4 # dimensionality of state space [x, xdot, a, adot]
Nt = 200
T = 4
# system parameters (in SI units)
F = 1
tau = 0.017 # [s]
g = 9.81 # [m/s^2]
nu = 13.2 # [1/s^2]
mu0 = 0
sig0 = 0.1
Sigmat = 0.01*sig0*np.eye(Nx)
t  = np.linspace(0, T, Nt)   # time grid
# reward parameters
Q = np.diag([1.25, 1, 12, 0.25])
R = 0.01
A = np.array([[1, tau, 0, 0],[0, 1, 0, 0],[0, 0, 1, tau],[0, 0, nu*tau, 1]])
B = np.diag([0, tau, 0, nu*tau/g])
#### CONTROL VARIABLES (later: to be learned, for now we use analytical optimal ones)
global K, eta, gamma
gamma = 0.2
K = np.array([5.7, 11.3, -82.1, -21.6])   # analytical opt.
eta = 1000 # "

In [3]:
#initial conditions [x0, x0dot, a0, a0dot] 
def ic(mu0, sig0, Nx):
    return sig0 * np.random.randn(Nx) + mu0

def reward(x,u,Q,R):
    return x*Q*x + u*R*u

def policy_map(eta):
    return 1/(1 + np.exp(eta)) + 0.1
def policy(K,x):
    m = K*x # control linear in the state
    s = policy_map(eta)
    u = m + s**2 * np.random.randn(Nx) # noisy control 
    # u = m # deterministic control
    return u
def check_stability():
    eigval, eigvec = LA.eig(A + B*K)
    if max(eigval) > 1/(gamma**2):
        print str(max(eigval)) + ": unstable!"
def f(x,t): 
    """system dynamics"""
    u = policy(K,x)
    mu = A.dot(x) + B.dot(u)
    #return Sigmat * np.random.randn(Nx) + mu
    return np.random.multivariate_normal(mu, Sigmat)

In [4]:
x0 = ic(mu0, sig0, Nx)  # IC

#print policy(K,x0)
#print policy_map(eta)

check_stability()
# solve the DEs
soln = odeint(f, x0, t)

xpos = soln[:, 0]
xvel = soln[:, 1]
angle= soln[:, 2]
anglevel= soln[:, 3]
plt.figure()
ax = plt.subplot(111)
plt.plot(t,xpos,t,xvel,t,angle,t,anglevel)
ax.set_ylim([-1,1])
#ax = plt.plot(t,xpos)
#ax.set_ylim([-1,1])


Excess work done on this call (perhaps wrong Dfun type).
Run with full_output = 1 to get quantitative information.
-c:9: RuntimeWarning: overflow encountered in exp
Out[4]:
(-1, 1)

In [55]:


In [ ]: