In [1]:
%pylab


Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib

In [2]:
from explauto.environment.toys.discrete_1d_progress import Discrete1dProgress
from explauto.agent import Agent
from explauto.interest_model.competences import competence_bool
from explauto.experiment import Experiment

In [3]:
from explauto import Environment, SensorimotorModel, InterestModel, Agent, Experiment
from explauto.sensorimotor_model.discrete import LidstoneModel
m_card = 7
s_card = 7
env = Discrete1dProgress(m_card=m_card, s_card=s_card)

sm_model = LidstoneModel(m_card, s_card)
im_model = InterestModel.from_configuration(env.conf, env.conf.s_dims, 'random')
ag = Agent(env.conf, sm_model, im_model)

expe = Experiment(env, ag)

In [1]:
%pylab inline


Populating the interactive namespace from numpy and matplotlib

In [2]:
class TaskCompetences(object):
    def __init__(self, time_step=0.01):
        # Saturating-exponential competence: baseline b, ceiling b + c, time constant p.
        self.score_function = lambda t, p, b, c: c * (1 - exp(-t / p)) + b
        self.scores = []
        self.scores.append([0.05, 0., 1.])
        self.scores.append([0.2, 0., 1.])
        self.scores.append([0.5, 0., 2.])
        self.times = zeros(len(self.scores))
        self.time_step = time_step
    def competence(self, task_index):
        score = self.score_function(self.times[task_index], *self.scores[task_index])
        self.times[task_index] += self.time_step
        return score
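
Note: each entry of self.scores is a (p, b, c) triple, so a task's competence follows the
saturating exponential c * (1 - exp(-t / p)) + b: a small time constant p means the task is
learned quickly, while c sets how far the ceiling lies above the baseline b. A minimal
standalone check of the three tasks defined above, recomputed outside the class with an
explicit numpy import (same parameters as in __init__):

import numpy as np

score = lambda t, p, b, c: c * (1 - np.exp(-t / p)) + b

# Competence of each task after 0.1 and 1.0 time units.
for p, b, c in [[0.05, 0., 1.], [0.2, 0., 1.], [0.5, 0., 2.]]:
    print(p, score(0.1, p, b, c), score(1.0, p, b, c))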

In [3]:
from explauto.interest_model.discrete_progress import DiscreteProgress
tasks = TaskCompetences()
#tasks.scores = [tasks.scores[0]] * len(tasks.scores)

In [4]:
y_all = []
ydiff_all = []
x = linspace(0, 1, 100)
for p, b, c in tasks.scores:
    y = []
    for xx in x:
        y.append(tasks.score_function(xx, p, b, c))
    y_all.append(y)
    ydiff = diff(y)
    ydiff_all.append(ydiff)
    subplot(211)
    plot(x, y)
    title('scores with time')
    subplot(212)
    plot(x[:-1], ydiff)
    title('score derivatives with time')



In [5]:
progress_tracker = DiscreteProgress(None, x_card=len(tasks.scores), win_size=10, measure=None)

In [6]:
task_index_history = []
progress_history = []
for i in range(300):
    #print progress_tracker.queues
    task_index = progress_tracker.sample(temp=3.)[0]
    progress_history.append(progress_tracker.w)
    #print task_index
    score = tasks.competence(task_index)
    #print score
    progress_tracker.update_from_index_and_competence(task_index, score)
    
    task_index_history.append(task_index)
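
For reference, a minimal standalone sketch of the temperature-controlled sampling step used
above, assuming DiscreteProgress picks a task through a softmax over the per-task progress
values (the real implementation also smooths competence over a window of size win_size; the
exact softmax form here is an assumption for illustration, not explauto's code):

import numpy as np

def softmax_sample(progress, temp=3., rng=np.random):
    # Higher temperature flattens the distribution over tasks; low temperature
    # concentrates the samples on the task with the largest progress.
    w = np.exp(np.asarray(progress, dtype=float) / temp)
    p = w / w.sum()
    return rng.choice(len(p), p=p)

# Example: task 1 has the largest progress and is therefore sampled most often.
counts = np.bincount([softmax_sample([0.1, 0.5, 0.05]) for _ in range(1000)], minlength=3)
print(counts)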

In [7]:
progress_history = array(progress_history)
plot(progress_history / sum(progress_history, axis=1).reshape(-1, 1).repeat(3, axis=1))
ylim([0,1])


Out[7]:
(0, 1)

In [21]:
progress_history = array(progress_history)
print(progress_history.shape)
sum(progress_history, axis=1).reshape(-1, 1).repeat(3, axis=1)


(300, 3)
Out[21]:
array([[ 0.5       ,  0.5       ,  0.5       ],
       [ 0.91327418,  0.91327418,  0.91327418],
       [ 1.91327418,  1.91327418,  1.91327418],
       [ 2.71085328,  2.71085328,  2.71085328],
       [ 2.71085328,  2.71085328,  2.71085328],
       [ 2.75843457,  2.75843457,  2.75843457],
       [ 2.83849558,  2.83849558,  2.83849558],
       [ 2.93737182,  2.93737182,  2.93737182],
       [ 3.52773764,  3.52773764,  3.52773764],
       [ 3.63306499,  3.63306499,  3.63306499],
       [ 3.73365593,  3.73365593,  3.73365593],
       [ 3.81938757,  3.81938757,  3.81938757],
       [ 4.13932769,  4.13932769,  4.13932769],
       [ 4.51778273,  4.51778273,  4.51778273],
       [ 4.73810284,  4.73810284,  4.73810284],
       [ 4.7998156 ,  4.7998156 ,  4.7998156 ],
       [ 4.9617554 ,  4.9617554 ,  4.9617554 ],
       [ 5.07647611,  5.07647611,  5.07647611],
       [ 5.10588208,  5.10588208,  5.10588208],
       [ 5.10931564,  5.10931564,  5.10931564],
       [ 5.09891559,  5.09891559,  5.09891559],
       [ 5.03982684,  5.03982684,  5.03982684],
       [ 4.7552856 ,  4.7552856 ,  4.7552856 ],
       [ 4.69829395,  4.69829395,  4.69829395],
       [ 4.64672578,  4.64672578,  4.64672578],
       [ 4.53346182,  4.53346182,  4.53346182],
       [ 4.29835384,  4.29835384,  4.29835384],
       [ 4.25169303,  4.25169303,  4.25169303],
       [ 3.88984552,  3.88984552,  3.88984552],
       [ 3.37551547,  3.37551547,  3.37551547],
       [ 3.33329502,  3.33329502,  3.33329502],
       [ 3.29509238,  3.29509238,  3.29509238],
       [ 3.2605252 ,  3.2605252 ,  3.2605252 ],
       [ 2.51215587,  2.51215587,  2.51215587],
       [ 2.48087819,  2.48087819,  2.48087819],
       [ 1.98763442,  1.98763442,  1.98763442],
       [ 1.9593332 ,  1.9593332 ,  1.9593332 ],
       [ 0.9727583 ,  0.9727583 ,  0.9727583 ],
       [ 0.96611623,  0.96611623,  0.96611623],
       [ 0.94815766,  0.94815766,  0.94815766],
       [ 0.93107494,  0.93107494,  0.93107494],
       [ 0.90546695,  0.90546695,  0.90546695],
       [ 0.88921736,  0.88921736,  0.88921736],
       [ 0.87376028,  0.87376028,  0.87376028],
       [ 0.85905705,  0.85905705,  0.85905705],
       [ 0.8525465 ,  0.8525465 ,  0.8525465 ],
       [ 0.83856036,  0.83856036,  0.83856036],
       [ 0.81538928,  0.81538928,  0.81538928],
       [ 0.79442323,  0.79442323,  0.79442323],
       [ 0.78111919,  0.78111919,  0.78111919],
       [ 0.76214832,  0.76214832,  0.76214832],
       [ 0.74498277,  0.74498277,  0.74498277],
       [ 0.73232758,  0.73232758,  0.73232758],
       [ 0.72594595,  0.72594595,  0.72594595],
       [ 0.71041391,  0.71041391,  0.71041391],
       [ 0.69635995,  0.69635995,  0.69635995],
       [ 0.68364339,  0.68364339,  0.68364339],
       [ 0.67213698,  0.67213698,  0.67213698],
       [ 0.66009899,  0.66009899,  0.66009899],
       [ 0.65384373,  0.65384373,  0.65384373],
       [ 0.64771232,  0.64771232,  0.64771232],
       [ 0.64170233,  0.64170233,  0.64170233],
       [ 0.63581135,  0.63581135,  0.63581135],
       [ 0.62436046,  0.62436046,  0.62436046],
       [ 0.61394902,  0.61394902,  0.61394902],
       [ 0.6030566 ,  0.6030566 ,  0.6030566 ],
       [ 0.59269541,  0.59269541,  0.59269541],
       [ 0.58692107,  0.58692107,  0.58692107],
       [ 0.57750042,  0.57750042,  0.57750042],
       [ 0.56764455,  0.56764455,  0.56764455],
       [ 0.56198455,  0.56198455,  0.56198455],
       [ 0.55260936,  0.55260936,  0.55260936],
       [ 0.54706143,  0.54706143,  0.54706143],
       [ 0.53853727,  0.53853727,  0.53853727],
       [ 0.52961931,  0.52961931,  0.52961931],
       [ 0.52113628,  0.52113628,  0.52113628],
       [ 0.5134233 ,  0.5134233 ,  0.5134233 ],
       [ 0.505354  ,  0.505354  ,  0.505354  ],
       [ 0.49837501,  0.49837501,  0.49837501],
       [ 0.49206015,  0.49206015,  0.49206015],
       [ 0.48438439,  0.48438439,  0.48438439],
       [ 0.47867048,  0.47867048,  0.47867048],
       [ 0.47323241,  0.47323241,  0.47323241],
       [ 0.46790203,  0.46790203,  0.46790203],
       [ 0.46267719,  0.46267719,  0.46267719],
       [ 0.45537578,  0.45537578,  0.45537578],
       [ 0.45020562,  0.45020562,  0.45020562],
       [ 0.4432603 ,  0.4432603 ,  0.4432603 ],
       [ 0.43813892,  0.43813892,  0.43813892],
       [ 0.43153234,  0.43153234,  0.43153234],
       [ 0.42651237,  0.42651237,  0.42651237],
       [ 0.42022799,  0.42022799,  0.42022799],
       [ 0.4142501 ,  0.4142501 ,  0.4142501 ],
       [ 0.40957194,  0.40957194,  0.40957194],
       [ 0.40533897,  0.40533897,  0.40533897],
       [ 0.4004184 ,  0.4004184 ,  0.4004184 ],
       [ 0.39559527,  0.39559527,  0.39559527],
       [ 0.39176512,  0.39176512,  0.39176512],
       [ 0.38607878,  0.38607878,  0.38607878],
       [ 0.38066976,  0.38066976,  0.38066976],
       [ 0.37594213,  0.37594213,  0.37594213],
       [ 0.37247647,  0.37247647,  0.37247647],
       [ 0.3693406 ,  0.3693406 ,  0.3693406 ],
       [ 0.36419539,  0.36419539,  0.36419539],
       [ 0.35930111,  0.35930111,  0.35930111],
       [ 0.35466709,  0.35466709,  0.35466709],
       [ 0.35001151,  0.35001151,  0.35001151],
       [ 0.34558298,  0.34558298,  0.34558298],
       [ 0.34137044,  0.34137044,  0.34137044],
       [ 0.33736334,  0.33736334,  0.33736334],
       [ 0.33452589,  0.33452589,  0.33452589],
       [ 0.33071422,  0.33071422,  0.33071422],
       [ 0.32814679,  0.32814679,  0.32814679],
       [ 0.32452102,  0.32452102,  0.32452102],
       [ 0.32107208,  0.32107208,  0.32107208],
       [ 0.31652983,  0.31652983,  0.31652983],
       [ 0.31420672,  0.31420672,  0.31420672],
       [ 0.31092599,  0.31092599,  0.31092599],
       [ 0.30647367,  0.30647367,  0.30647367],
       [ 0.30437164,  0.30437164,  0.30437164],
       [ 0.30000749,  0.30000749,  0.30000749],
       [ 0.29688676,  0.29688676,  0.29688676],
       [ 0.29498476,  0.29498476,  0.29498476],
       [ 0.29070703,  0.29070703,  0.29070703],
       [ 0.2877385 ,  0.2877385 ,  0.2877385 ],
       [ 0.28354547,  0.28354547,  0.28354547],
       [ 0.28072172,  0.28072172,  0.28072172],
       [ 0.27803568,  0.27803568,  0.27803568],
       [ 0.27392568,  0.27392568,  0.27392568],
       [ 0.26989706,  0.26989706,  0.26989706],
       [ 0.26734202,  0.26734202,  0.26734202],
       [ 0.26491159,  0.26491159,  0.26491159],
       [ 0.26096275,  0.26096275,  0.26096275],
       [ 0.25865085,  0.25865085,  0.25865085],
       [ 0.25645171,  0.25645171,  0.25645171],
       [ 0.25473071,  0.25473071,  0.25473071],
       [ 0.25086006,  0.25086006,  0.25086006],
       [ 0.24706605,  0.24706605,  0.24706605],
       [ 0.24550882,  0.24550882,  0.24550882],
       [ 0.24341694,  0.24341694,  0.24341694],
       [ 0.23969805,  0.23969805,  0.23969805],
       [ 0.23770819,  0.23770819,  0.23770819],
       [ 0.23406294,  0.23406294,  0.23406294],
       [ 0.23217012,  0.23217012,  0.23217012],
       [ 0.23076109,  0.23076109,  0.23076109],
       [ 0.22718803,  0.22718803,  0.22718803],
       [ 0.22591308,  0.22591308,  0.22591308],
       [ 0.22241077,  0.22241077,  0.22241077],
       [ 0.22061026,  0.22061026,  0.22061026],
       [ 0.21889757,  0.21889757,  0.21889757],
       [ 0.21546461,  0.21546461,  0.21546461],
       [ 0.21209962,  0.21209962,  0.21209962],
       [ 0.210946  ,  0.210946  ,  0.210946  ],
       [ 0.20764765,  0.20764765,  0.20764765],
       [ 0.20601848,  0.20601848,  0.20601848],
       [ 0.20278544,  0.20278544,  0.20278544],
       [ 0.20123573,  0.20123573,  0.20123573],
       [ 0.19806671,  0.19806671,  0.19806671],
       [ 0.19702287,  0.19702287,  0.19702287],
       [ 0.1939166 ,  0.1939166 ,  0.1939166 ],
       [ 0.19244247,  0.19244247,  0.19244247],
       [ 0.18939771,  0.18939771,  0.18939771],
       [ 0.18799547,  0.18799547,  0.18799547],
       [ 0.18666162,  0.18666162,  0.18666162],
       [ 0.18539283,  0.18539283,  0.18539283],
       [ 0.18240836,  0.18240836,  0.18240836],
       [ 0.18120144,  0.18120144,  0.18120144],
       [ 0.18025694,  0.18025694,  0.18025694],
       [ 0.17910888,  0.17910888,  0.17910888],
       [ 0.17825426,  0.17825426,  0.17825426],
       [ 0.17532888,  0.17532888,  0.17532888],
       [ 0.17455559,  0.17455559,  0.17455559],
       [ 0.17385588,  0.17385588,  0.17385588],
       [ 0.17276382,  0.17276382,  0.17276382],
       [ 0.1721307 ,  0.1721307 ,  0.1721307 ],
       [ 0.16926325,  0.16926325,  0.16926325],
       [ 0.16822445,  0.16822445,  0.16822445],
       [ 0.16541378,  0.16541378,  0.16541378],
       [ 0.16442564,  0.16442564,  0.16442564],
       [ 0.1634857 ,  0.1634857 ,  0.1634857 ],
       [ 0.16073068,  0.16073068,  0.16073068],
       [ 0.15803022,  0.15803022,  0.15803022],
       [ 0.15713611,  0.15713611,  0.15713611],
       [ 0.15656324,  0.15656324,  0.15656324],
       [ 0.15391625,  0.15391625,  0.15391625],
       [ 0.15132167,  0.15132167,  0.15132167],
       [ 0.15047118,  0.15047118,  0.15047118],
       [ 0.14966216,  0.14966216,  0.14966216],
       [ 0.14889259,  0.14889259,  0.14889259],
       [ 0.14816056,  0.14816056,  0.14816056],
       [ 0.14746423,  0.14746423,  0.14746423],
       [ 0.14694588,  0.14694588,  0.14694588],
       [ 0.14440268,  0.14440268,  0.14440268],
       [ 0.14374031,  0.14374031,  0.14374031],
       [ 0.14311024,  0.14311024,  0.14311024],
       [ 0.14251091,  0.14251091,  0.14251091],
       [ 0.1419408 ,  0.1419408 ,  0.1419408 ],
       [ 0.1413985 ,  0.1413985 ,  0.1413985 ],
       [ 0.14092947,  0.14092947,  0.14092947],
       [ 0.13843663,  0.13843663,  0.13843663],
       [ 0.13599315,  0.13599315,  0.13599315],
       [ 0.13556876,  0.13556876,  0.13556876],
       [ 0.13518475,  0.13518475,  0.13518475],
       [ 0.1346689 ,  0.1346689 ,  0.1346689 ],
       [ 0.1322738 ,  0.1322738 ,  0.1322738 ],
       [ 0.12992613,  0.12992613,  0.12992613],
       [ 0.12943544,  0.12943544,  0.12943544],
       [ 0.12908797,  0.12908797,  0.12908797],
       [ 0.12862121,  0.12862121,  0.12862121],
       [ 0.12632003,  0.12632003,  0.12632003],
       [ 0.12600563,  0.12600563,  0.12600563],
       [ 0.12572115,  0.12572115,  0.12572115],
       [ 0.12346553,  0.12346553,  0.12346553],
       [ 0.12302153,  0.12302153,  0.12302153],
       [ 0.12259919,  0.12259919,  0.12259919],
       [ 0.12219744,  0.12219744,  0.12219744],
       [ 0.11998649,  0.11998649,  0.11998649],
       [ 0.11972908,  0.11972908,  0.11972908],
       [ 0.11949617,  0.11949617,  0.11949617],
       [ 0.11928542,  0.11928542,  0.11928542],
       [ 0.11890327,  0.11890327,  0.11890327],
       [ 0.1167361 ,  0.1167361 ,  0.1167361 ],
       [ 0.11461184,  0.11461184,  0.11461184],
       [ 0.11424832,  0.11424832,  0.11424832],
       [ 0.11405763,  0.11405763,  0.11405763],
       [ 0.11197543,  0.11197543,  0.11197543],
       [ 0.11162965,  0.11162965,  0.11162965],
       [ 0.11130072,  0.11130072,  0.11130072],
       [ 0.10925976,  0.10925976,  0.10925976],
       [ 0.10894688,  0.10894688,  0.10894688],
       [ 0.10877433,  0.10877433,  0.10877433],
       [ 0.10677378,  0.10677378,  0.10677378],
       [ 0.10647616,  0.10647616,  0.10647616],
       [ 0.10451522,  0.10451522,  0.10451522],
       [ 0.10259311,  0.10259311,  0.10259311],
       [ 0.10070906,  0.10070906,  0.10070906],
       [ 0.09886232,  0.09886232,  0.09886232],
       [ 0.09705214,  0.09705214,  0.09705214],
       [ 0.09527781,  0.09527781,  0.09527781],
       [ 0.09512169,  0.09512169,  0.09512169],
       [ 0.09483858,  0.09483858,  0.09483858],
       [ 0.09469731,  0.09469731,  0.09469731],
       [ 0.09456949,  0.09456949,  0.09456949],
       [ 0.09283029,  0.09283029,  0.09283029],
       [ 0.09271463,  0.09271463,  0.09271463],
       [ 0.09100987,  0.09100987,  0.09100987],
       [ 0.08933887,  0.08933887,  0.08933887],
       [ 0.08906957,  0.08906957,  0.08906957],
       [ 0.0888134 ,  0.0888134 ,  0.0888134 ],
       [ 0.08856973,  0.08856973,  0.08856973],
       [ 0.08846508,  0.08846508,  0.08846508],
       [ 0.08682717,  0.08682717,  0.08682717],
       [ 0.08522168,  0.08522168,  0.08522168],
       [ 0.0849899 ,  0.0849899 ,  0.0849899 ],
       [ 0.08341621,  0.08341621,  0.08341621],
       [ 0.08332151,  0.08332151,  0.08332151],
       [ 0.08323583,  0.08323583,  0.08323583],
       [ 0.08301534,  0.08301534,  0.08301534],
       [ 0.08147282,  0.08147282,  0.08147282],
       [ 0.07996083,  0.07996083,  0.07996083],
       [ 0.07847878,  0.07847878,  0.07847878],
       [ 0.07702609,  0.07702609,  0.07702609],
       [ 0.07681635,  0.07681635,  0.07681635],
       [ 0.07539242,  0.07539242,  0.07539242],
       [ 0.07399668,  0.07399668,  0.07399668],
       [ 0.07379718,  0.07379718,  0.07379718],
       [ 0.07360741,  0.07360741,  0.07360741],
       [ 0.07352988,  0.07352988,  0.07352988],
       [ 0.07345973,  0.07345973,  0.07345973],
       [ 0.07327921,  0.07327921,  0.07327921],
       [ 0.07321574,  0.07321574,  0.07321574],
       [ 0.0731583 ,  0.0731583 ,  0.0731583 ],
       [ 0.07298659,  0.07298659,  0.07298659],
       [ 0.07282325,  0.07282325,  0.07282325],
       [ 0.07266788,  0.07266788,  0.07266788],
       [ 0.07261591,  0.07261591,  0.07261591],
       [ 0.07124781,  0.07124781,  0.07124781],
       [ 0.0699068 ,  0.0699068 ,  0.0699068 ],
       [ 0.06985977,  0.06985977,  0.06985977],
       [ 0.06971198,  0.06971198,  0.06971198],
       [ 0.06966943,  0.06966943,  0.06966943],
       [ 0.06952884,  0.06952884,  0.06952884],
       [ 0.06949034,  0.06949034,  0.06949034],
       [ 0.06935661,  0.06935661,  0.06935661],
       [ 0.0692294 ,  0.0692294 ,  0.0692294 ],
       [ 0.0691084 ,  0.0691084 ,  0.0691084 ],
       [ 0.06907356,  0.06907356,  0.06907356],
       [ 0.06895846,  0.06895846,  0.06895846],
       [ 0.06764401,  0.06764401,  0.06764401],
       [ 0.06761248,  0.06761248,  0.06761248],
       [ 0.06758396,  0.06758396,  0.06758396],
       [ 0.06629553,  0.06629553,  0.06629553],
       [ 0.06618605,  0.06618605,  0.06618605],
       [ 0.06492313,  0.06492313,  0.06492313],
       [ 0.06481898,  0.06481898,  0.06481898],
       [ 0.06479317,  0.06479317,  0.06479317],
       [ 0.06476982,  0.06476982,  0.06476982],
       [ 0.06353191,  0.06353191,  0.06353191],
       [ 0.06231852,  0.06231852,  0.06231852],
       [ 0.06229739,  0.06229739,  0.06229739]])

In [96]:
ydiff_all = array(ydiff_all)
ydiff_norm = ydiff_all / sum(ydiff_all, axis=0).reshape(1, -1).repeat(3, axis=0)
ydiff_norm.shape
plot(ydiff_norm.T)


Out[96]:
[<matplotlib.lines.Line2D at 0x6788f50>,
 <matplotlib.lines.Line2D at 0x67a3110>,
 <matplotlib.lines.Line2D at 0x67a3290>]

In [59]:
from explauto.environment.toys.hierarchical_mild import f
x = linspace(0, 10, 100)
y_hist = []
for temp in linspace(1e-5, 10., 6):
    y = f(x, temp) + 100*rand()
    y_hist.append(y)
    plot(x, y)
    #plot(diff(y))



In [ ]:
plot(argmin(array(y_hist), axis=0))

In [24]:
argmin(array(y_hist), axis=0)


Out[24]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0])

In [5]:
commands = array(env.random_motors(100), dtype=int)
ms = env.dataset(commands)
plot(ms[:, 0], ms[:, 1], 'o')


Out[5]:
[<matplotlib.lines.Line2D at 0x6a78390>]

In [10]:
n_trials = 200
progr = zeros((n_trials, s_card))
for i in range(n_trials):
    expe.run(1)
    progr[i,:] = ag.i_model.progress()


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-10-03be2db6938d> in <module>()
      2 progr = zeros((n_trials, s_card))
      3 for i in range(n_trials):
----> 4     expe.run(1)
      5     progr[i,:] = ag.i_model.progress()

/home/clement/Documents/Boulot/INRIA_FLOWERS/CODE/explauto/explauto/experiment/experiment.pyc in run(self, n_iter, bg)
     65             self._t.start()
     66         else:
---> 67             self._run(n_iter)
     68 
     69     def wait(self):

/home/clement/Documents/Boulot/INRIA_FLOWERS/CODE/explauto/explauto/experiment/experiment.pyc in _run(self, n_iter)
     83             self.notifications.queue.clear()
     84 
---> 85             m = self.ag.produce()
     86             try:
     87                 env_state = self.env.update(m)

/home/clement/Documents/Boulot/INRIA_FLOWERS/CODE/explauto/explauto/agent/agent.pyc in produce(self)
    128 
    129         x = self.choose()
--> 130         y = self.infer(self.expl_dims, self.inf_dims, x)
    131 
    132         self.m, self.s = self.extract_ms(x, y)

/home/clement/Documents/Boulot/INRIA_FLOWERS/CODE/explauto/explauto/agent/agent.pyc in infer(self, expl_dims, inf_dims, x)
     88             y = self.sensorimotor_model.infer(expl_dims,
     89                                               inf_dims,
---> 90                                               x.flatten())
     91         except ExplautoBootstrapError:
     92             logger.warning('Sensorimotor model not bootstrapped yet')

/home/clement/Documents/Boulot/INRIA_FLOWERS/CODE/explauto/explauto/sensorimotor_model/discrete.py in infer(self, in_dims, out_dims, x)
     19             p_out = self.joint_distr()[x, :]
     20         else:
---> 21             p_out = self.joint_distr()[:, x]
     22         p_out /= p_out.sum()
     23         return discrete_random_draw(p_out.flatten())

IndexError: arrays used as indices must be of integer (or boolean) type

In [ ]:
%debug


> /home/clement/Documents/Boulot/INRIA_FLOWERS/CODE/explauto/explauto/sensorimotor_model/discrete.py(21)infer()
     20         else:
---> 21             p_out = self.joint_distr()[:, x]
     22         p_out /= p_out.sum()

ipdb> p x
array([ 4.57077553])
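
So the IndexError comes from the interest model handing back a float sample
(array([ 4.57077553])), which LidstoneModel.infer then uses directly as an index into the
joint distribution. A small standalone reproduction and workaround, assuming the
round-to-an-integer-index reading of the problem (an illustration, not a patch of explauto):

import numpy as np

joint = np.ones((7, 7)) / 49.           # stand-in for joint_distr(), m_card = s_card = 7
x = np.array([4.57077553])              # float goal sampled by the interest model

# joint[:, x] raises "arrays used as indices must be of integer (or boolean) type".
# Rounding the sampled goal to an integer index avoids the error:
idx = np.asarray(np.round(x), dtype=int)
p_out = joint[:, idx].flatten()
p_out /= p_out.sum()
print(p_out)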

In [8]:
clf()
plot(progr)
legend([str(i) for i in range(s_card)])
[text(n_trials * 1.1, progr[-1, x], str(x)) for x in range(s_card)]


Out[8]:
[<matplotlib.text.Text at 0x3eb5190>,
 <matplotlib.text.Text at 0x3eb3b10>,
 <matplotlib.text.Text at 0x3eb3690>,
 <matplotlib.text.Text at 0x3eb37d0>,
 <matplotlib.text.Text at 0x3eb0190>,
 <matplotlib.text.Text at 0x3eb0d10>,
 <matplotlib.text.Text at 0x3eb0150>]

In [7]:
figure()
for s in range(s_card):
    inds = nonzero(ag.i_model.choices.flatten()[:ag.i_model.t] == s)[0]
    plot(inds, ag.i_model.comps[inds], '*', ms=12)
legend([str(i) for i in range(s_card)])


Out[7]:
<matplotlib.legend.Legend at 0x40b4590>

In [9]:
2**8


Out[9]:
256

In [8]:
figure()
win = 40
spent = []
for t in range(expe_1.i_rec - win):
    spent.append([float(sum(expe_1.records[t:t+win, 1] == s)) / win * 100 for s in range(s_card)])
plot(spent)
legend([str(i) for i in range(s_card)])


Out[8]:
<matplotlib.legend.Legend at 0x40c8650>

In [10]:
x = linspace(0., 1., 100)
plot(x, exp(2. * x) / exp(2.))


Out[10]:
[<matplotlib.lines.Line2D at 0x579d850>]

In [15]:
zip(range(4), x)


Out[15]:
[(0, 10), (1, 11), (2, 12), (3, 13)]

In [22]:
n_trials = 10
logs_x = zeros((n_trials, s_card))
for trial in range(n_trials):
    del ag, expe_1
    ag = Agent(**myconf)
    expe_1 = Experiment(env, ag, 0, [0,1])
    expe_1.run(10)
    logs_x[trial, :] = [(ag.choices[:ag.t, :] == x).sum() for x in range(s_card)]

In [23]:
logs_x.sum(axis=0)


Out[23]:
array([ 73.,   0.,   0.,  19.,   8.,   0.])

In [24]:
logs_x


Out[24]:
array([[  2.,   0.,   0.,   7.,   1.,   0.],
       [  1.,   0.,   0.,   9.,   0.,   0.],
       [  6.,   0.,   0.,   3.,   1.,   0.],
       [  8.,   0.,   0.,   0.,   2.,   0.],
       [  7.,   0.,   0.,   0.,   3.,   0.],
       [  9.,   0.,   0.,   0.,   1.,   0.],
       [ 10.,   0.,   0.,   0.,   0.,   0.],
       [ 10.,   0.,   0.,   0.,   0.,   0.],
       [ 10.,   0.,   0.,   0.,   0.,   0.],
       [ 10.,   0.,   0.,   0.,   0.,   0.]])

In [14]:
env.s_card


Out[14]:
6

In [23]:


In [18]:
print ag.choices[:ag.t, :]


[[ 2.]
 [ 5.]
 [ 2.]
 [ 3.]
 [ 4.]
 [ 1.]
 [ 5.]
 [ 3.]
 [ 4.]
 [ 0.]
 [ 2.]
 [ 2.]
 [ 2.]
 [ 2.]
 [ 5.]
 [ 4.]
 [ 1.]
 [ 5.]
 [ 5.]
 [ 4.]]

In [25]:
[(ag.choices[:ag.t, :] == x).sum() for x in range(s_card)]


Out[25]:
[2, 4, 11, 8, 7, 8]

In [2]:
from model.sm_model import LidstoneModel

In [3]:
model = LidstoneModel(3, 4)

In [4]:
model.counts


Out[4]:
array([[ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.]])

In [28]:
model.update(2, 1)

In [6]:
p


Out[6]:
array([[ 0.08333333,  0.08333333,  0.08333333,  0.08333333],
       [ 0.08333333,  0.08333333,  0.08333333,  0.08333333],
       [ 0.08333333,  0.08333333,  0.08333333,  0.08333333]])

In [36]:
from collections import deque

In [38]:
q = deque(zeros(3), 3)

In [54]:
q.append(array([2., 0.]))

In [55]:
print q


deque([array([ 0.,  0.]), array([ 1.,  0.]), array([ 2.,  0.])], maxlen=3)

In [61]:
cov(array(q), rowvar = 0)


Out[61]:
array([[ 1.,  0.],
       [ 0.,  0.]])

In [68]:
d = deque([[t, 0.] for t in range(4)])
print d


deque([[0, 0.0], [1, 0.0], [2, 0.0], [3, 0.0]])

In [69]:
cov(d, rowvar=0)


Out[69]:
array([[ 1.66666667,  0.        ],
       [ 0.        ,  0.        ]])
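
With rowvar=0 each row of the window is one observation, so the off-diagonal term cov[0, 1]
is proportional to the slope of the second column over time, i.e. a local progress estimate
(it is 0 above because the second column is constant). A minimal sketch under that reading:

import numpy as np
from collections import deque

win = deque(maxlen=4)
for t, comp in enumerate([0.0, 0.1, 0.3, 0.6]):   # competence increasing over the window
    win.append([t, comp])

c = np.cov(np.array(win), rowvar=0)
progress = c[0, 1]      # covariance between time and competence: positive when improving
print(progress)         # 0.333...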

In [70]:
[0, 1] == 0


Out[70]:
False

In [11]:
d = array([0., 0., 0.])
d = d / d.sum()


-c:2: RuntimeWarning: invalid value encountered in divide
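
As the warning says, normalizing an all-zero vector is a 0/0 and fills d with NaNs. A minimal
guard, assuming a uniform fallback is what is wanted when no counts have been accumulated yet
(consistent with the uniform array seen in Out[34] below):

import numpy as np

def normalize(d):
    # Return d / sum(d), falling back to a uniform distribution when the sum is zero.
    d = np.asarray(d, dtype=float)
    s = d.sum()
    return d / s if s > 0 else np.ones_like(d) / len(d)

print(normalize([0., 0., 0.]))   # [ 0.33333333  0.33333333  0.33333333]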

In [67]:
(discrete_random_draw(d, nb=10000) == 2).sum()


Out[67]:
3347

In [30]:
d


Out[30]:
array([0, 0, 0])

In [31]:
d = array([1, 1, 1])

In [34]:
d


Out[34]:
array([ 0.33333333,  0.33333333,  0.33333333])

In [ ]: