In [1]:
%matplotlib inline
import pylab
import numpy as np
import functools
import adapt_float
import adapt_spike
import adapt_fixed_small

import ctn_benchmark.control as ctrl

In [19]:
def objective(args, adapt_type):
    """Simulate a 1-D control task with a PID controller plus an optional
    adaptive compensator, and score it for hyperopt.

    Parameters
    ----------
    args : dict
        Must contain the PID gains 'Kp', 'Kd', 'Ki'.
    adapt_type : str
        'float' for the floating-point adaptive model, 'fixed' for the
        small fixed-point one; any other value runs pure PID.

    Returns
    -------
    dict with keys 'status' ('ok') and 'loss' (RMSE between desired and
    actual trajectory), the format hyperopt.fmin expects.
    """
    D = 1            # dimensionality of the signal / system
    dt = 0.001       # simulation timestep (s)
    T = 20.0         # total simulated time (s)
    seed = 2
    noise = 0.1      # motor and sensor noise level
    Kp = args['Kp']
    Kd = args['Kd']
    Ki = args['Ki']
    print('trying', Kp, Kd, Ki)   # print() works on both Python 2 and 3
    tau_d = 0.001    # derivative filter time constant for the PID
    period = 4       # period of the desired signal (s)
    amplitude = 1
    max_freq = 1.0
    scale_add = 2
    delay = 0.01     # motor/sensor delay (s)
    filt = 0.01      # motor/sensor filter time constant (renamed from
                     # `filter`, which shadowed the builtin)
    n_neurons = 256  # (a dead earlier assignment of 500 was removed)

    signal = ctrl.Signal(D, period, dt=dt, max_freq=max_freq, seed=seed)

    system = ctrl.System(D, D, dt=dt, seed=seed,
            motor_noise=noise, sense_noise=noise,
            scale_add=scale_add,
            motor_scale=10,
            motor_delay=delay, sensor_delay=delay,
            motor_filter=filt, sensor_filter=filt)

    pid = ctrl.PID(Kp, Kd, Ki, tau_d=tau_d)

    if adapt_type == 'float':
        adapt = adapt_float.AdaptiveFloat(n_inputs=D, n_outputs=D, n_neurons=n_neurons, seed=seed,
                                     learning_rate=1e-4)
        scale = 1.0
    elif adapt_type == 'fixed':
        adapt = adapt_fixed_small.AdaptiveFixed(n_inputs=D, n_outputs=D, n_neurons=n_neurons, seed=seed,
                                                input_bits=8,
                                                state_bits=8,
                                                extra_bits=4,
                                                decoder_offset=4,
                                                decoder_bits=8,
                                                smoothing=10,
                                                learning_rate=1e-4)
        scale = adapt.input_max   # fixed-point model works in scaled units
    else:
        adapt = None              # pure PID baseline

    steps = int(T / dt)
    data_desired = np.zeros((steps, D))
    data_actual = np.zeros((steps, D))
    data_pid = np.zeros((steps, D))

    for i in range(steps):
        desired = signal.value(i*dt)*amplitude
        data_desired[i,:] = desired

        actual = system.state
        data_actual[i,:] = actual

        raw_pid = pid.step(actual, desired)
        data_pid[i,:] = raw_pid

        if adapt is not None:
            # Scale into the adaptive model's units, then back out.
            adjust = adapt.step(actual*scale, -raw_pid*scale)/scale
        else:
            adjust = 0

        system.step(raw_pid + adjust)

    #pylab.plot(data_desired)
    #pylab.plot(data_actual)

    rmse = np.sqrt(np.mean((data_desired - data_actual)**2))

    return dict(
        status='ok',
        loss=rmse
        )

In [20]:
objective(dict(Kp=10, Kd=2, Ki=0), adapt_type='none')  # baseline: pure PID, no adaptive term


trying 10 2 0
Out[20]:
{'loss': 0.072578950102291276, 'status': 'ok'}

In [21]:
objective(dict(Kp=10, Kd=2, Ki=0), adapt_type='fixed')  # fixed-point adaptive compensator


trying 10 2 0
Out[21]:
{'loss': 0.072970952117567919, 'status': 'ok'}

In [22]:
objective(dict(Kp=10, Kd=2, Ki=0), adapt_type='float')  # floating-point adaptive compensator


trying 10 2 0
Out[22]:
{'loss': 0.059855123871354471, 'status': 'ok'}

In [23]:
import hyperopt as hp

# Hyperopt search space: log-normal priors over the three PID gains,
# centred (in log space) on rough hand-tuned values.
space = dict(
    Kp=hp.hp.lognormal('Kp', np.log(50), 1),  # proportional gain, median 50
    Kd=hp.hp.lognormal('Kd', np.log(20), 1),  # derivative gain, median 20
    Ki=hp.hp.lognormal('Ki', np.log(10), 1)   # integral gain, median 10
)

# Records every evaluation (params + loss) for inspection after the search.
trials = hp.Trials()

In [24]:
# `functools` is already imported in the first cell, so the redundant
# re-import was dropped. Optimize the 'fixed' variant's gains with TPE.
best = hp.fmin(functools.partial(objective, adapt_type='fixed'),
               space=space, algo=hp.tpe.suggest,
               max_evals=100, trials=trials)


trying 48.0636816073 16.2987669305 2.46178366298
trying 92.5729886695 65.0721638221 4.91166141722
trying 48.0398255288 3.17482782935 3.19844471815
trying 17.9539178295 88.7210425554 15.6807818443
trying 74.8258451113 6.99973237925 20.2892323172
trying 8.43359536229 9.99713979308 13.8032569622
trying 897.076064842 28.6485988342 5.7337186985
trying 7.5009130424 14.4902233997 3.07147213852
trying 35.8666531335 16.6624350367 28.0480475585
trying 110.163033325 11.3324356337 18.6203292967
trying 54.3198948901 18.9432649878 21.4039280332
trying 128.07416403 12.3182046412 4.56613353946
trying 26.6975591582 3.63792779669 6.10351552529
trying 17.3133828666 15.4765039702 6.57461542176
trying 154.735393575 11.0165492512 20.556972767
trying 20.4648568995 20.6184201207 5.4720307666
trying 44.4673737099 18.3526371101 17.6197533515
trying 21.2545245067 17.6608414429 4.5114527937
trying 15.5493181917 11.2995708554 19.8429364432
trying 39.3132823447 20.6080107599 8.35235393022
trying 11.9957169664 299.937955177 14.095668344
trying 1.96741000535 100.361727826 49.4388039969
trying 13.6136606154 38.4257838431 24.6914215695
trying 11.7917075964 57.8519310885 30.9653276304
trying 12.4294637335 135.43393628 29.8652314298
trying 10.4976747133 41.2281306986 11.8505489652
trying 14.2878852181 171.495811664 1.0045772569
trying 9.39808608123 32.9233359876 24.7357687609
trying 206.582211683 70.0915188843 37.2192369747
trying 4.76989101797 46.7492083397 16.4489632973
trying 13.2398253462 23.4746421061 120.691286385
trying 6.02107093015 89.9374420139 9.05058612851
trying 17.0695243043 178.149949753 22.6664180068
trying 60.8229964467 7.53004453451 38.4299888476
trying 29.3780545141 39.2599965156 11.0939452957
trying 68.7775320573 55.7132820052 13.8668937254
trying 19.8038518578 27.2359903317 15.9391364981
trying 10.6389765313 75.5368563733 60.0995341405
trying 86.4147065114 124.396718683 12.7052442759
trying 14.4611517761 33.7098787054 25.3718300246
trying 15.4241493243 5.79351741267 15.2302224302
trying 24.4913397585 696.70677606 10.4755061614
trying 18.0743442812 388.923448882 8.39334037593
trying 256.448894435 9.02559497454 35.1419125046
trying 33.0521713318 24.9095227568 42.1553413194
trying 4.027903759 32.7705697352 26.3980078813
trying 22.9065216672 48.7586916042 18.4105317662
trying 45.3953282963 13.6426449618 12.6484206055
trying 15.1156295601 4.84377951303 2.74309557404
trying 16.1940906417 78.2822276302 6.8458459534
trying 11.0106016562 22.4088636664 7.0268547342
trying 18.099748048 32.5817022767 3.89246648992
trying 16.1306606971 61.3511505595 2.02243292771
trying 16.7283218953 85.488083786 1.54120984727
trying 18.8516896734 1.96299260414 0.321165127925
trying 16.384574696 77.2588408927 1.88868024905
trying 50.6304494599 63.6506401823 1.13628899736
trying 107.23334204 92.9230300956 0.609366199741
trying 7.3822296231 107.87862526 3.87107902176
trying 38.5778033242 82.2646282775 1.76234497041
trying 29.3022574034 68.0782728792 3.42162757579
trying 21.0341901092 75.567001316 4.82792219627
trying 16.011273937 99.3616225105 2.18424657072
trying 2.99366375121 57.0725838916 5.55071868082
trying 22.3294454475 63.5882278866 7.00133117282
trying 14.2331721895 34.5894617337 74.0629538975
trying 12.7859318674 51.1345324136 0.797986272406
trying 14.0704786722 46.3524875155 1.30632548349
trying 14.741786643 29.8547840171 6.0614564913
trying 18.9349550497 60.4285140141 6.29716177439
trying 13.6028683151 29.9749731642 7.81551543089
trying 11.8249160122 30.2632856606 7.87113234518
trying 14.8151119153 16.8304017398 7.48574084701
trying 13.50879572 42.2088185277 9.17199474009
trying 11.3644500139 19.5570695802 9.38441900147
trying 9.43303299519 26.8033910774 2.73005361952
trying 12.873285519 36.3784207279 5.91088022122
trying 13.7967379122 21.9244619846 8.44499350086
trying 12.4340269943 29.1359440536 5.19580391309
trying 340.276526615 21.7615793906 8.41662109145
trying 13.8115829811 25.5359318732 7.50315187738
trying 10.0458212073 24.4013503299 10.0397389183
trying 63.4513690599 27.1025499449 6.4219418528
trying 12.2708639761 26.245020129 8.80042837364
trying 8.02812593838 20.9373897096 7.24764201503
trying 11.3849946621 23.6632284669 6.65471327543
trying 17.6184845554 25.7693130517 6.04813027769
trying 15.6513872673 27.8980383514 7.90786202736
trying 156.12833927 31.4129250308 7.46471855684
trying 12.9081957317 18.0328477831 9.54482581509
trying 14.6703933141 19.0932276738 10.8344211649
trying 10.0128318322 22.4150399468 8.08684181732
trying 85.7078873383 15.6303149489 8.58664927266
trying 28.6390808835 23.2490688702 5.71473977457
trying 13.8292065156 24.7230551729 9.8581826709
trying 6.60663123699 28.3931320445 4.44697400342
trying 31.8723609385 20.190816576 7.49561291061
trying 15.2786206002 25.368871193 5.35669658693
trying 13.1288473368 31.3877825938 6.71609191536
trying 54.7885295301 21.1931264055 6.42996069594

In [25]:
best  # best gains found by the search (display via rich repr)


Out[25]:
{'Kd': 29.854784017116174, 'Ki': 6.0614564913019953, 'Kp': 14.741786642982683}

In [26]:
# Plot the loss of every trial in the order hyperopt evaluated them.
loss = [t['result']['loss'] for t in trials.trials]  # one RMSE per evaluation
pylab.plot(loss)
pylab.xlabel('trial')
pylab.ylabel('rmse')
pylab.ylim(0, 0.15)
pylab.show()
print('best rmse', np.min(loss))  # print() works on both Python 2 and 3


best rmse 0.0549544908238

In [27]:
# Bucket each trial by how far its loss is above the best loss seen:
# within 5%, 5-10%, or more than 10% worse.
limits = [1.0, 1.05, 1.1, np.inf]
limit_labels = ['<5%', '5-10%', '10%+']
best_rmse = np.min(loss)

def _categorize(ratio):
    """Return the label of the first bucket whose upper edge is >= ratio."""
    k = 0
    while limits[k + 1] < ratio:
        k += 1
    return limit_labels[k]

data = []
for trial in trials.trials:
    vals = trial['misc']['vals']
    record = dict(
        Kp=vals['Kp'][0],
        Kd=vals['Kd'][0],
        Ki=vals['Ki'][0],
        category=_categorize(trial['result']['loss'] / best_rmse),
    )
    data.append(record)

In [28]:
import pandas
import seaborn as sns

In [29]:
df = pandas.DataFrame(data)  # one row per trial: Kp, Kd, Ki, category

In [30]:
# Pairwise scatter of the sampled gains, coloured by loss category, with
# the best-found gains marked by black cross-hairs on every panel.
grid = sns.pairplot(data=df, x_vars=['Kp', 'Kd', 'Ki'], y_vars=['Kp', 'Kd', 'Ki'],
                    diag_kind='kde', hue='category', hue_order=limit_labels)
ax_lim = 50
best_vals = best['Kp'], best['Kd'], best['Ki']

for row_idx, axes_row in enumerate(grid.axes):
    for col_idx, axis in enumerate(axes_row):
        axis.set_xlim(0, ax_lim)
        # On the diagonal row_idx == col_idx, so this single call covers
        # both the diagonal and off-diagonal vertical markers.
        axis.axvline(best_vals[col_idx], c='k')
        if row_idx != col_idx:
            axis.set_ylim(0, ax_lim)
            axis.axhline(best_vals[row_idx], c='k')



In [ ]: