In [1]:
import numpy as np
from matplotlib import pyplot as plt
import random
from sklearn.naive_bayes import GaussianNB
import math
from collections import Counter

In [7]:
LANE_WIDTH = 4
ROAD_LENGTH = 100
SPEED = 10
NUM_LANES = 3

def generate_labeled_data(N):
    trajectories = []
    labels = []
    for i in range(N):
        start_lane = random.choice([0,1,2])
        label = "keep"
        if start_lane == 0:
            end_lane = random.choice([0,1])
        elif start_lane == 1:
            end_lane = random.choice([0,1,2])
        else:
            end_lane = random.choice([1,2])
        if end_lane > start_lane: 
            label = "right"
        elif end_lane < start_lane:
            label = "left"
        start_s = 0.0
        dist = random.gauss(40, 3)
        end_s = start_s + dist
        start_d = LANE_WIDTH * start_lane + random.gauss(0, float(LANE_WIDTH) / 10)
        end_d  =  LANE_WIDTH * end_lane   + random.gauss(0, float(LANE_WIDTH) / 10)
        speed = max(random.gauss(SPEED, 1.0), 6)
        T = (end_s - start_s) / speed
        alpha_s = JMT([start_s, speed, 0], [end_s, speed + random.gauss(0,0.5), 0], T)
        alpha_d = JMT([start_d, 0, 0], [end_d, 0, 0], T)
        s = to_equation(alpha_s)
        s_dot = to_equation(differentiate(alpha_s))
        d = to_equation(alpha_d)
        d_dot = to_equation(differentiate(alpha_d))
        coords = []
        for i in range(0, int(T*100), int(T) ):
            t = float(i) / 100
            coords.append((s(t), d(t), s_dot(t), d_dot(t)))
        trajectories.append(coords)
        labels.append(label)
    return trajectories, labels

def differentiate(coefficients):
    new_cos = []
    for deg, prev_co in enumerate(coefficients[1:]):
        new_cos.append((deg+1) * prev_co)
    return new_cos


N = 1000
trajs, labels = generate_labeled_data(N)
TRAIN_DATA = trajs[:3 * N / 4]
TRAIN_LABELS = labels[:3 * N / 4]
TEST_DATA  = trajs[3 * N / 4:]
TEST_LABEL = labels[3 * N / 4:]

with o

T = 15
def process_obs(obs):
    s,d,s_dot, d_dot = obs
    return (
        s,
#             (d % LANE_WIDTH) ** 2,
#             d_dot ** 2,
#             abs(d),
#             d/d_dot,
            )
def train_naive_bayes(data,labels, t):
    X = []
    transform_label = {'left' : -1, 'keep' : 0, 'right' : 1}
    Y = [transform_label[lab] for lab in labels]
    for trajectory in data:
        obs = trajectory[t]
        X.append(process_obs(obs))
    clf = GaussianNB()
    clf.fit(X, Y)
    return clf

# clf = train_naive_bayes(TRAIN_DATA, TRAIN_LABELS, T)


[  1.97542344e-02  -4.26268092e-03   1.24951773e-05]
[  1.32243103e-03  -2.13547214e-04   4.21180890e-07]
[  2.50135994e-04  -5.28782474e-05   1.45901362e-07]
[ -1.18374202e-01   1.87280731e-02  -3.47560767e-05]
[  1.35391514e-02  -3.75185194e-03   2.28737135e-05]
[  3.27936696e-01  -6.79154447e-02   2.80242622e-04]
[ -2.65292731e-02   6.80997047e-03  -3.32174770e-05]
[ -2.51908493e-01   4.83491994e-02  -1.59281883e-04]
[  1.89865453e-02  -5.19526752e-03   3.05295647e-05]
[  2.46722491e-01  -5.04580221e-02   2.00611650e-04]
[  3.73718652e-02  -7.47800628e-03   1.75468627e-05]
[  9.62253060e-03  -1.44131122e-03   2.27263560e-06]
[ -1.22200654e-02   2.62964695e-03  -7.64588309e-06]
[  2.51552621e-02  -4.05094228e-03   7.92464046e-06]
[ -1.18373243e-02   3.10302134e-03  -1.60918445e-05]
[  6.53135013e-03  -1.27999379e-03   4.48562675e-06]
[ -1.48666385e-02   3.10309065e-03  -8.24728895e-06]
[  1.54788359e-02  -2.41811698e-03   4.32169130e-06]
[  2.17694387e-02  -5.86175641e-03   3.28709337e-05]
[ -3.06988202e-01   6.17881898e-02  -2.34316791e-04]
[ -3.50601365e-02   1.01424343e-02  -7.00590331e-05]
[ -5.90566966e-03   1.27642908e-03  -5.97547730e-06]
[ -1.79953358e-02   5.43902765e-03  -4.26530524e-05]
[  6.16009866e-02  -1.39062811e-02   7.40177381e-05]
[  1.98470710e-02  -5.05819753e-03   2.41601812e-05]
[  1.24968523e-02  -2.38146551e-03   7.68112615e-06]
[ -3.85327490e-02   9.97996172e-03  -4.99664398e-05]
[ -2.46167821e-02   4.76688266e-03  -1.61227839e-05]
[ -2.65978542e-02   6.75945898e-03  -3.20191648e-05]
[ -2.60405108e-01   4.94841450e-02  -1.58273865e-04]
[  1.34020703e-02  -3.41418719e-03   1.62874936e-05]
[  2.72534990e-03  -5.19138261e-04   1.67232822e-06]
[ -4.48519850e-02   1.09984806e-02  -4.69295440e-05]
[ -2.36670619e-01   4.34041278e-02  -1.24943241e-04]
[  8.92946126e-03  -2.05450889e-03   7.27258564e-06]
[ -1.47256205e-02   2.53472153e-03  -6.04461574e-06]
[ 0.02605885 -0.00964124  0.00013496]
[-0.88480461  0.24407668 -0.00233921]
[ -3.59199734e-02   9.91957749e-03  -5.98728684e-05]
[  3.11177369e-01  -6.42242637e-02   2.62341070e-04]
[ -5.28610414e-03   1.17780104e-03  -3.79376339e-06]
[ -2.23963460e-02   3.73380213e-03  -8.09705628e-06]
[  5.73569936e-03  -1.43687846e-03   6.52677037e-06]
[ -3.71162154e-02   6.95317422e-03  -2.13182673e-05]
[  1.29726464e-02  -2.93889896e-03   9.94044135e-06]
[ -1.60040101e-01   2.71262719e-02  -6.17918879e-05]
[  3.09715787e-02  -9.36628037e-03   7.35693626e-05]
[ -1.56662073e-02   3.53857086e-03  -1.88652353e-05]
[  6.98246191e-03  -1.66484046e-03   6.54303497e-06]
[ -2.36450613e-02   4.21699347e-03  -1.11737609e-05]
[  2.95734114e-02  -7.38755631e-03   3.32788881e-05]
[  5.03532643e-02  -9.40631500e-03   2.85988063e-05]
[  9.47948873e-04  -2.52196324e-04   1.36549864e-06]
[ -8.68718588e-03   1.72769944e-03  -6.32395817e-06]
[ -2.21692331e-02   5.70741595e-03  -2.80780808e-05]
[ -2.36363595e-01   4.54976308e-02  -1.51183710e-04]
[  3.64586184e-02  -9.74386275e-03   5.34621827e-05]
[  1.05430617e-01  -2.10630332e-02   7.81374644e-05]
[-0.0386568   0.01222143 -0.00010912]
[ 0.48216427 -0.11381578  0.0006909 ]
[ -1.50479412e-02   3.39701724e-03  -1.13712094e-05]
[ -3.16041872e-02   5.33798739e-03  -1.20330587e-05]
[ -9.20691360e-04   2.09184530e-04  -7.13593916e-07]
[ -1.79361453e-02   3.04890593e-03  -7.00505147e-06]
[  1.06419686e-02  -2.63223460e-03   1.15191490e-05]
[ -4.58719830e-04   8.48527369e-05  -2.50563911e-07]
[ -4.99790358e-04   1.53959397e-04  -1.27551869e-06]
[ -6.26670981e-02   1.44164245e-02  -8.11203703e-05]
[  1.07455204e-02  -2.66834746e-03   1.18126979e-05]
[ -2.85944244e-01   5.31010126e-02  -1.58638062e-04]
[  4.71047351e-03  -1.24978021e-03   6.71323779e-06]
[  3.61879759e-01  -7.17754252e-02   2.60621326e-04]
[  4.02076578e-02  -1.03451865e-02   5.08052826e-05]
[ -3.16312083e-01   6.08507811e-02  -2.01845161e-04]
[ -2.09712885e-03   5.32769214e-04  -2.52112801e-06]
[  1.12210466e-02  -2.13156931e-03   6.81077453e-06]
[ 0.06201727 -0.01726129  0.00010659]
[ -3.57596742e-01   7.43817005e-02  -3.10908037e-04]
[ -4.64030752e-02   1.20775586e-02  -6.13418617e-05]
[ -2.52301197e-02   4.90957150e-03  -1.68474190e-05]
[  3.60418456e-02  -8.96925431e-03   3.99573836e-05]
[  2.04799055e-01  -3.81134912e-02   1.14588182e-04]
[ -1.30294706e-02   2.28169659e-03  -3.60896508e-06]
[  6.24467596e-03  -8.18949781e-04   8.68836381e-07]
[  6.46546038e-03  -1.82770982e-03   1.18072146e-05]
[ -3.24125795e-03   6.84681619e-04  -2.99547716e-06]
[ -1.98752700e-02   5.47797527e-03  -3.28763813e-05]
[ -4.13384677e-01   8.51531733e-02  -3.45835810e-04]
[ -1.27328200e-02   2.52909539e-03  -5.80670926e-06]
[  2.64614203e-02  -3.93453006e-03   6.06968371e-06]
[ -2.34646538e-02   5.58466796e-03  -2.18330571e-05]
[  1.91837606e-01  -3.41522489e-02   9.00137821e-05]
[  3.40092313e-02  -1.00467653e-02   7.37428972e-05]
[  2.11059982e-02  -4.65768363e-03   2.31855887e-05]
[  1.16025223e-03  -3.07271519e-04   1.64168711e-06]
[  2.82687431e-01  -5.59659152e-02   2.02118312e-04]
[  1.36879999e-02  -3.04867623e-03   9.80899917e-06]
[ -1.92548150e-01   3.20884921e-02  -6.95083608e-05]
[  9.58274910e-03  -2.26817604e-03   8.72501226e-06]
[  2.45210120e-02  -4.34149532e-03   1.12576663e-05]
[ -6.62406846e-03   2.00649585e-03  -1.58350450e-05]
[ 0.35196419 -0.07962819  0.00042656]
[ -1.02245670e-02   2.15997488e-03  -5.94779013e-06]
[ -1.11257461e-01   1.75901398e-02  -3.25781469e-05]
[ -3.99257890e-03   1.05616009e-03  -5.62416806e-06]
[  2.40907118e-01  -4.76404722e-02   1.71476705e-04]
[ -9.01783639e-04   2.19555221e-04  -9.17383727e-07]
[  2.34127452e-01  -4.26329836e-02   1.20156875e-04]
[ -2.89912344e-02   6.67169353e-03  -2.36304095e-05]
[  1.51569557e-02  -2.60948518e-03   6.22657774e-06]
[ -1.31824971e-03   3.45803921e-04  -1.79692201e-06]
[ -2.66617772e-01   5.22869396e-02  -1.83609349e-04]
[ 0.12506963 -0.04083394  0.00039996]
[-0.4642206   0.11313349 -0.00075432]
[ 0.14667371 -0.04624534  0.0004097 ]
[ 0.49782279 -0.11719562  0.0007058 ]
[  2.58802651e-03  -6.74276151e-04   3.43471100e-06]
[ -2.21589221e-02   4.31625451e-03  -1.48553522e-05]
[  1.81743257e-02  -4.19366950e-03   1.49711329e-05]
[ -1.59105549e-02   2.74656051e-03  -6.60591981e-06]
[ -2.01903089e-02   5.44662905e-03  -3.07082464e-05]
[  4.58791463e-02  -9.25120775e-03   3.52745915e-05]
[ -3.03827927e-02   6.59307098e-03  -1.96482030e-05]
[  1.25819575e-02  -2.04313098e-03   4.09723835e-06]
[  2.48140031e-03  -5.67417122e-04   1.97251623e-06]
[  2.68434490e-02  -4.59230132e-03   1.07535291e-05]
[  2.31162618e-02  -6.57074828e-03   4.31314012e-05]
[ -3.77580662e-01   8.01970116e-02  -3.56572897e-04]
[ -5.78580388e-02   1.61002722e-02  -9.93579096e-05]
[  6.45485745e-02  -1.34235890e-02   5.60746735e-05]
[ -2.92460251e-02   7.27462024e-03  -3.23630008e-05]
[  2.37735878e-01  -4.42222453e-02   1.32768385e-04]
[ -3.26121832e-02   7.52611923e-03  -2.68778177e-05]
[ -1.94243020e-01   3.35354647e-02  -8.06885766e-05]
[ -7.07978375e-03   1.53574314e-03  -4.57170685e-06]
[  1.45791331e-01  -2.36656759e-02   4.74062559e-05]
[ -9.00827184e-03   2.35904690e-03  -1.21979016e-05]
[ -1.97424025e-01   3.86519446e-02  -1.35052449e-04]
[ -2.51289388e-02   6.54410568e-03  -3.32919792e-05]
[ -2.71935154e-01   5.29458560e-02  -1.81986721e-04]
[  1.05563575e-02  -2.35963960e-03   7.67267461e-06]
[ -1.46554837e-01   2.45111345e-02  -5.36623028e-05]
[ -1.83126072e-02   5.84655518e-03  -5.36969066e-05]
[-0.42917969  0.10229781 -0.00063899]
[  2.15176507e-02  -5.35428930e-03   2.38461719e-05]
[ -2.18603648e-01   4.06786080e-02  -1.22265039e-04]
[  2.59082608e-02  -5.86571051e-03   1.98032895e-05]
[  1.61459413e-01  -2.73496776e-02   6.21847318e-05]
[ -3.52106700e-02   9.27351122e-03  -4.87542859e-05]
[  1.14224586e-01  -2.24900846e-02   7.99113755e-05]
[ -3.50290525e-02   8.48939982e-03  -3.49982788e-05]
[  1.80787438e-02  -3.27702056e-03   9.11168286e-06]
[ -1.51956464e-02   4.10492486e-03  -2.32371877e-05]
[  3.00293177e-03  -6.06353705e-04   2.32144145e-06]
[  2.25690359e-02  -5.81213398e-03   2.86189803e-05]
[ -1.85140819e-01   3.56486726e-02  -1.18564227e-04]
[  9.75647978e-03  -2.19169537e-03   7.23132568e-06]
[ -1.79270504e-01   3.01312977e-02  -6.69426291e-05]
[ -8.44216280e-03   1.76558796e-03  -4.71979653e-06]
[  8.48608395e-03  -1.32830221e-03   2.38783958e-06]
[ -1.83004990e-02   5.77917628e-03  -5.14335467e-05]
[  7.46943252e-03  -1.76118926e-03   1.06556810e-05]
[ -2.44317510e-02   8.35933199e-03  -9.36476874e-05]
[-0.13299867  0.03395359 -0.00025944]
[-0.07579357  0.02033308 -0.0001128 ]
[ -1.63382942e-02   3.27635343e-03  -1.22900527e-05]
[ -4.15671354e-02   1.05385923e-02  -4.95749054e-05]
[ -5.43413396e-02   1.03019601e-02  -3.27203778e-05]
[  1.33622788e-02  -3.03443079e-03   1.03360827e-05]
[ -1.59639546e-01   2.71229935e-02  -6.22239629e-05]
[  4.04852060e-02  -9.89274687e-03   4.17781277e-05]
[  2.04860715e-01  -3.74388881e-02   1.06656499e-04]
[  5.10419433e-03  -1.28520625e-03   5.92543030e-06]
[ -2.57754054e-02   4.85316146e-03  -1.51048590e-05]
[ -7.15654063e-04   1.53973479e-04  -4.47442509e-07]
[  1.35561549e-01  -2.18264418e-02   4.26742689e-05]
[  3.47479438e-02  -9.13475943e-03   4.77668361e-05]
[  2.97391107e-01  -5.84470146e-02   2.06547067e-04]
[ -2.05808243e-02   5.45960259e-03  -2.93125061e-05]
[  2.81511919e-02  -5.58261765e-03   2.02611297e-05]
[  1.80980400e-02  -4.69946298e-03   2.37061566e-05]
[ -4.82283926e-02   9.36304879e-03  -3.19091686e-05]
[ -3.52492874e-02   1.02987807e-02  -7.32137784e-05]
[ 0.34926201 -0.07623521  0.00036742]
[ -7.17612099e-03   2.08173277e-03  -1.44959150e-05]
[-0.40044792  0.08679042 -0.00040962]
[ -1.39554221e-02   3.76117743e-03  -2.11482404e-05]
[ -3.66612120e-02   7.38564516e-03  -2.80843125e-05]
[ -3.98792335e-02   9.98712880e-03  -4.53221014e-05]
[  3.15034783e-01  -5.89982175e-02   1.80715478e-04]
[ -8.01380652e-04   2.26748444e-04  -1.46872157e-06]
[  3.25081268e-01  -6.87324820e-02   3.01513404e-04]
[ 0.09632518 -0.02804983  0.00019749]
[  7.26312919e-02  -1.58013109e-02   7.54152352e-05]
[ -6.90450549e-03   1.80282749e-03  -9.24243724e-06]
[ -2.37908665e-01   4.64425016e-02  -1.60877996e-04]
[  2.32912670e-02  -5.30601494e-03   1.82430588e-05]
[ -2.52416230e-01   4.30216117e-02  -9.96288472e-05]
[  1.31109614e-02  -2.79941048e-03   7.95441114e-06]
[ -1.50238391e-02   2.40065852e-03  -4.58885109e-06]

In [4]:
def JMT(start, end, T):
    a_0, a_1, a_2 = start[0], start[1], start[2] / 2.0
    c_0 = a_0 + a_1 * T + a_2 * T**2
    c_1 = a_1 + a_2 * T
    c_2 = 2 * a_2
    
    A = np.array([
            [  T**3,   T**4,    T**5],
            [3*T**2, 4*T**3,  5*T**4],
            [6*T,   12*T**2, 20*T**5],
        ])
    B = np.array([
            end[0] - c_0,
            end[1] - c_1,
            end[2] - c_2
        ])
    a_3_4_5 = np.linalg.solve(A,B)
    print a_3_4_5
    alphas = np.concatenate([np.array([a_0, a_1, a_2]), a_3_4_5])
    return alphas

def to_equation(coefficients):
    def f(t):
        total = 0.0
        for i, c in enumerate(coefficients): 
            total += c * t ** i
        return total
    return f

def PTG(start_s, start_d, end_s, end_d, T, time_step = 0.1):
    s_coefficients = JMT(start_s, end_s, T)
    d_coefficients = JMT(start_d, end_d, T)
    print "s_coefficients: {}".format(s_coefficients)
    s = to_equation(s_coefficients)
    d = to_equation(d_coefficients)
    coords = []
    t = 0
    while t <= T:
        new = (s(t), d(t), t)
        coords.append(new)
        t += time_step
    return coords

In [73]:
cnt = Counter()
for l in TRAIN_LABELS + TEST_LABEL:
    cnt[l] += 1

In [74]:
class Predictor():
    def __init__(self):


Out[74]:
Counter({'keep': 464, 'left': 253, 'right': 283})

In [71]:
transform_label = {'left' : -1, 'keep' : 0, 'right' : 1}
Y = [transform_label[l] for l in TEST_LABEL]
X = [process_obs(td[T]) for td in TEST_DATA]
clf.score(X, Y)


Out[71]:
0.46400000000000002

In [66]:
clf.predict_proba(X[20])


Out[66]:
array([[  3.04824214e-04,   0.00000000e+00,   9.99695176e-01]])

In [67]:
Y[20]


Out[67]:
1

In [9]:
COL = {'keep' : 'black', 'left'  : 'red', 'right':'blue'}

for coords, label in zip(trajs, labels):
    S = [p[0] for p in coords]
    D = [p[1] for p in coords]
    S_DOT = [p[2] for p in coords]
    D_DOT = [p[3] for p in coords]
    print label
    color = COL[label]
    plt.scatter(S,D,color=color)
plt.show()


keep
left
right
left
right
keep
keep
keep
keep
left
keep
keep
keep
keep
left
keep
left
keep
left
right
keep
keep
left
keep
keep
keep
keep
left
keep
right
keep
keep
keep
keep
left
right
left
keep
left
keep
right
keep
keep
left
keep
right
keep
right
left
keep
right
left
right
right
keep
left
left
right
keep
keep
keep
keep
keep
left
keep
right
left
right
left
left
left
left
left
right
keep
keep
keep
left
left
keep
keep
keep
keep
keep
left
right
keep
right
right
keep
keep
right
left
keep
right
right
keep
left
left
keep

In [9]:
coords = generate_lane_change(1,0)
X = [p[0] for p in coords]
Y = [p[1] for p in coords]


[ 0.  0. -0.]
[ -1.28796680e-01   1.95186722e-02  -3.18672199e-05]

In [10]:
plt.scatter(X,Y)


Out[10]:
<matplotlib.collections.PathCollection at 0x107ccb1d0>

In [11]:
plt.show()


//anaconda/lib/python2.7/site-packages/matplotlib/collections.py:590: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  if self._edgecolors == str('face'):