Sketchbook for experimenting with RNN models of the system dynamics. Once tested, the code in this notebook will be incorporated into functions, classes, and Python modules.

Dynamics Model

The dynamics model represents p(next observation | history of observations, action).

The history of observations consists of the exercises a student has done and whether the student solved each of them.

The action is the next exercise chosen.

The next observation is whether the student answers the chosen exercise correctly.
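Writing e_t for the exercise attempted at step t, c_t in {0, 1} for whether it was solved, and N for the number of exercises (notation introduced here for illustration only), the modelled distribution and the model's output vector at step t can be written as:

```latex
p\bigl(c_{t+1} \mid (e_1, c_1), \dots, (e_t, c_t),\, e_{t+1}\bigr),
\qquad
\hat{y}_t \in [0, 1]^{N},
\qquad
\hat{y}_t[e_{t+1}] \approx p\bigl(c_{t+1} = 1 \mid (e_1, c_1), \dots, (e_t, c_t),\, e_{t+1}\bigr)
```

Conditioning on the action amounts to reading off entry e_{t+1} of the output vector, so a single forward pass covers every candidate action.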

We want to use an RNN to model the dynamics. The input data represents the history of observations and has shape (n_students, n_timesteps, observation_vec_size).

The output represents the probability of answering the next exercise correctly and has shape (n_students, n_timesteps, n_exercises).

So at each timestep, we make a prediction for every possible action.

For each action, the corresponding entry of the output vector is the predicted probability that the student answers that exercise correctly.

The target output contains only binary values (1 if the student solved the exercise, 0 otherwise).
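The notebook relies on dataset_utils.preprocess_data_for_rnn for this step, and its internals are not shown here. The sketch below is a hypothetical encoding that is merely consistent with the shapes printed later (20 input features for 10 exercises): each observation is a one-hot vector over (exercise, outcome) pairs, and the output mask marks the exercise attempted at the following step, whose outcome is the binary target. The name encode_student and the slot layout are assumptions, not the actual preprocessing.

```python
import numpy as np


def encode_student(history, n_exercises):
    """Hypothetical one-hot encoding of a single student's history.

    history: list of (exercise_id, correct) pairs, with correct in {0, 1}.
    Returns (inputs, mask, targets) with shapes
    (T, 2 * n_exercises), (T, n_exercises), (T, n_exercises), where T = len(history).
    """
    T = len(history)
    inputs = np.zeros((T, 2 * n_exercises), dtype=np.float32)
    mask = np.zeros((T, n_exercises), dtype=np.float32)
    targets = np.zeros((T, n_exercises), dtype=np.float32)
    for t, (exercise, correct) in enumerate(history):
        # Assumed layout: first n_exercises slots = "attempted, incorrect",
        # last n_exercises slots = "attempted, correct".
        inputs[t, exercise + correct * n_exercises] = 1.0
        if t + 1 < T:
            next_exercise, next_correct = history[t + 1]
            # The prediction made at step t is scored only on the exercise chosen at t + 1.
            mask[t, next_exercise] = 1.0
            targets[t, next_exercise] = next_correct
    # The last step has no next exercise, so its mask and target rows stay all zeros.
    return inputs, mask, targets


# Usage: stack students with equal-length histories into (n_students, n_timesteps, ...) arrays.
histories = [[(3, 1), (5, 0), (3, 1)], [(0, 0), (0, 1), (9, 1)]]
encoded = [encode_student(h, n_exercises=10) for h in histories]
input_data, output_mask, target_data = (np.stack(parts) for parts in zip(*encoded))
```

Keeping the mask and target the same length as the input (with an empty last row) matches the notebook's use of a single n_timesteps for all three arrays; real preprocessing might instead shift or pad the sequences differently.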


In [16]:
import sys
print sys.executable
%load_ext autoreload
%autoreload 2


/Users/lisa1010/tf_venv/bin/python
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

In [2]:
import sonnet as snt
import tensorflow as tf
import tflearn
import numpy as np

In [3]:
import dataset_utils

Load the toy dataset


In [5]:
data = dataset_utils.load_data(filename="../synthetic_data/toy.pickle")
input_data_, output_mask_, target_data_ = dataset_utils.preprocess_data_for_rnn(data)

Build RNN Model


In [6]:
tflearn.init_graph()


Out[6]:
gpu_options {
}
allow_soft_placement: true

In [9]:
n_hidden = 64
n_samples, n_timesteps, n_inputdim = input_data_.shape
_, _, n_outputdim = target_data_.shape

In [10]:
print n_timesteps


50

In [12]:
print n_inputdim


20

In [13]:
print n_outputdim


10
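In the model cells, tflearn.lstm(..., return_seq=True) returns a Python list of n_timesteps tensors, each of shape (batch, n_units), rather than one 3-D tensor; that list also appears to be what triggers the "'list' object has no attribute 'name'" serialization warnings in the training log. tf.stack(net, axis=1) rebuilds a (batch, n_timesteps, n_units) tensor so it can be multiplied elementwise by output_mask. A numpy analogue, with made-up values and sizes matching the toy data:

```python
import numpy as np

batch, n_timesteps, n_units = 4, 50, 10
rng = np.random.RandomState(0)

# With return_seq=True, the LSTM output is a list with one (batch, n_units) array per timestep.
per_step = [rng.rand(batch, n_units).astype(np.float32) for _ in range(n_timesteps)]

# Stacking along axis 1 rebuilds a single (batch, n_timesteps, n_units) array,
# the numpy counterpart of tf.stack(net, axis=1).
stacked = np.stack(per_step, axis=1)

# Only now does the elementwise product with a (batch, n_timesteps, n_units) mask line up.
mask = np.zeros((batch, n_timesteps, n_units), dtype=np.float32)
mask[:, :, 2] = 1.0
masked = stacked * mask
```

Stacking along axis 0 instead would give (n_timesteps, batch, n_units), which would not match the layout of output_mask.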

In [15]:
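graph_to_use = tf.Graph()
with graph_to_use.as_default():
    # Observations: (batch, n_timesteps, n_inputdim)
    net = tflearn.input_data([None, n_timesteps, n_inputdim], dtype=tf.float32, name='input_data')
    # Mask selecting which exercise's prediction is scored at each step: (batch, n_timesteps, n_outputdim)
    output_mask = tflearn.input_data([None, n_timesteps, n_outputdim], dtype=tf.float32, name='output_mask')
    # return_seq=True makes each LSTM return a list of n_timesteps tensors
    net = tflearn.lstm(net, n_hidden, return_seq=True, name="lstm_1")
    # One output per exercise; the default LSTM activation is tanh, so values lie in (-1, 1)
    net = tflearn.lstm(net, n_outputdim, return_seq=True, name="lstm_2")
    # Turn the per-timestep list into a (batch, n_timesteps, n_outputdim) tensor
    net = tf.stack(net, axis=1)
    preds = net
    # Zero out predictions for exercises that were not attempted
    net = net * output_mask
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                             loss='mean_square')
    model = tflearn.DNN(net, tensorboard_verbose=0)
    model.fit([input_data_, output_mask_], target_data_, validation_set=0.1)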


---------------------------------
Run id: 2PTSBL
Log directory: /tmp/tflearn_logs/
WARNING:tensorflow:Error encountered when serializing layer_tensor/lstm_1.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'
WARNING:tensorflow:Error encountered when serializing layer_tensor/lstm_2.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'
---------------------------------
Training samples: 4
Validation samples: 1
--
Training Step: 1  | time: 2.116s
| Adam | epoch: 001 | loss: 0.00000 | val_loss: 0.04748 -- iter: 4/4
--
Training Step: 2  | total loss: 0.04031 | time: 1.067s
| Adam | epoch: 002 | loss: 0.04031 | val_loss: 0.04669 -- iter: 4/4
--
Training Step: 3  | total loss: 0.04343 | time: 1.065s
| Adam | epoch: 003 | loss: 0.04343 | val_loss: 0.04589 -- iter: 4/4
--
Training Step: 4  | total loss: 0.04341 | time: 1.061s
| Adam | epoch: 004 | loss: 0.04341 | val_loss: 0.04506 -- iter: 4/4
--
Training Step: 5  | total loss: 0.04293 | time: 1.060s
| Adam | epoch: 005 | loss: 0.04293 | val_loss: 0.04419 -- iter: 4/4
--
Training Step: 6  | total loss: 0.04231 | time: 1.056s
| Adam | epoch: 006 | loss: 0.04231 | val_loss: 0.04326 -- iter: 4/4
--
Training Step: 7  | total loss: 0.04168 | time: 1.061s
| Adam | epoch: 007 | loss: 0.04168 | val_loss: 0.04223 -- iter: 4/4
--
Training Step: 8  | total loss: 0.04099 | time: 1.055s
| Adam | epoch: 008 | loss: 0.04099 | val_loss: 0.04111 -- iter: 4/4
--
Training Step: 9  | total loss: 0.04024 | time: 1.059s
| Adam | epoch: 009 | loss: 0.04024 | val_loss: 0.03985 -- iter: 4/4
--
Training Step: 10  | total loss: 0.03943 | time: 1.053s
| Adam | epoch: 010 | loss: 0.03943 | val_loss: 0.03846 -- iter: 4/4
--
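Because the predictions are multiplied by output_mask before the mean_square regression, only the entry for the exercise attempted at the next step contributes an error term, assuming the targets are zero wherever the mask is zero. tflearn's mean_square averages over every entry of the (batch, timesteps, exercises) tensor, so the reported loss is diluted by roughly a factor of n_exercises relative to the error on the observed entries, which helps explain the small loss values above. The two helper names below are hypothetical:

```python
import numpy as np


def loss_as_in_notebook(preds, mask, targets):
    # mean_square on (preds * mask): squared error averaged over every entry,
    # including the n_exercises - 1 masked-out entries per timestep.
    return float(np.mean(np.square(preds * mask - targets)))


def loss_over_observed_entries(preds, mask, targets):
    # Same squared errors, but averaged only over the entries selected by the mask.
    squared = np.square(preds * mask - targets) * mask
    return float(squared.sum() / max(mask.sum(), 1.0))


# Toy check: 1 student, 2 timesteps, 10 exercises, constant prediction of 0.5.
preds = np.full((1, 2, 10), 0.5)
mask = np.zeros((1, 2, 10))
mask[0, 0, 3] = mask[0, 1, 7] = 1.0
targets = mask.copy()  # the student solves both chosen exercises
```

Here the notebook-style loss is 0.025 while the error per observed entry is 0.25, a factor of 10 (the number of exercises) apart.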

In [72]:
tf.get_collection(tf.GraphKeys.INPUTS)


Out[72]:
[<tf.Tensor 'output_mask:0' shape=(?, 50, 10) dtype=int32>,
 <tf.Tensor 'InputData/X:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_1:0' shape=(?, 50, 10) dtype=int32>,
 <tf.Tensor 'InputData_1/X:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_2:0' shape=(?, 50, 10) dtype=int32>,
 <tf.Tensor 'InputData_2/X:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_3:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'InputData_3/X:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'input_data:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_4:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'input_data_1:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_5:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'input_data_2:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_6:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'input_data_3:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_7:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'input_data_4:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_8:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'output_mask_9:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'input_data_5/X:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'input_data_6/X:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_10/X:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'output_mask_9:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'input_data_7/X:0' shape=(?, 50, 20) dtype=float32>,
 <tf.Tensor 'output_mask_11/X:0' shape=(?, 50, 10) dtype=float32>,
 <tf.Tensor 'output_mask_9:0' shape=(?, 50, 10) dtype=float32>]

In [17]:
data = dataset_utils.load_data(filename="../synthetic_data/1000stud_100seq_expert.pickle")
input_data_, output_mask_, target_data_ = dataset_utils.preprocess_data_for_rnn(data)

In [20]:
n_samples, n_timesteps, n_inputdim = input_data_.shape
_, _, n_outputdim = target_data_.shape
print n_samples
print n_timesteps
print n_inputdim
print n_outputdim


1000
100
20
10

In [22]:
graph_to_use = tf.Graph()
with graph_to_use.as_default():
    net = tflearn.input_data([None, n_timesteps, n_inputdim], dtype=tf.float32, name='input_data')
    output_mask = tflearn.input_data([None, n_timesteps, n_outputdim], dtype=tf.float32, name='output_mask')
    net = tflearn.lstm(net, n_hidden, return_seq=True, name="lstm_1")
    net = tflearn.lstm(net, n_outputdim, return_seq=True, name="lstm_2")
    net = tf.stack(net, axis=1)
    preds = net
    net = net * output_mask
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                             loss='mean_square')
    model = tflearn.DNN(net, tensorboard_verbose=2)
    model.fit([input_data_, output_mask_], target_data_, n_epoch=32, validation_set=0.1)


Training Step: 479  | total loss: 0.02480 | time: 2.358s
| Adam | epoch: 032 | loss: 0.02480 -- iter: 896/900
Training Step: 480  | total loss: 0.02479 | time: 3.469s
| Adam | epoch: 032 | loss: 0.02479 | val_loss: 0.02476 -- iter: 900/900
--
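The second tflearn.lstm layer produces the model outputs directly, and its default activation is tanh, so those outputs lie in (-1, 1) rather than [0, 1]; the negative values returned by predict further down are consistent with this. One common alternative, not what this notebook does, is to keep the LSTM as a hidden layer and apply the same dense layer with a sigmoid at every timestep. The numpy sketch below only illustrates the shapes and the squashing; W and b are hypothetical weights. In tflearn, something like tflearn.time_distributed(net, tflearn.fully_connected, [n_outputdim, 'sigmoid']) applied to the stacked tensor may play this role, but that has not been tried here.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def time_distributed_sigmoid(hidden, W, b):
    """Apply one dense layer followed by a sigmoid to every timestep.

    hidden: (batch, T, n_hidden) LSTM states; W: (n_hidden, n_exercises); b: (n_exercises,).
    Returns a (batch, T, n_exercises) array with values strictly between 0 and 1.
    """
    # The same W and b are shared across all timesteps.
    logits = np.einsum('bth,he->bte', hidden, W) + b
    return sigmoid(logits)


rng = np.random.RandomState(0)
hidden = np.tanh(rng.randn(2, 5, 64))  # tanh-range states, like an LSTM's output
W = 0.1 * rng.randn(64, 10)            # hypothetical weights
b = np.zeros(10)
probs = time_distributed_sigmoid(hidden, W, b)
```

With sigmoid outputs, a binary cross-entropy on the masked entries would be a natural replacement for the mean-square loss, since the targets are binary.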

In [23]:
from dynamics_model import *

In [39]:
model = load_model(model_id="test_model", load_checkpoint=False, is_training=True)


Loading RNN dynamics model...
Directory path for tensorboard summaries: ../tensorboard_logs/test_model/
Checkpoint directory path: ../checkpoints/test_model/
Model loaded.

In [40]:
# Use only the first 10 timesteps of each sequence for a quick test run
train_data = (input_data_[:, :10, :], output_mask_[:, :10, :], target_data_[:, :10, :])

In [41]:
train(model, train_data)


Training Step: 44  | total loss: 0.03538 | time: 0.681s
| Adam | epoch: 003 | loss: 0.03538 -- iter: 896/900
Training Step: 45  | total loss: 0.03482 | time: 1.726s
| Adam | epoch: 003 | loss: 0.03482 | val_loss: 0.03224 -- iter: 900/900
--

KeyboardInterruptTraceback (most recent call last)
<ipython-input-41-175a1436c9bb> in <module>()
----> 1 train(model, train_data)

/Users/lisa1010/dev/smart-tutor/code/dynamics_model.pyc in train(model, train_data, load_checkpoint)
     95     date_time_string = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
     96     run_id = "{}".format(date_time_string)
---> 97     model.fit([input_data, output_mask], output_data, n_epoch=64, validation_set=0.1)
     98 
     99 

KeyboardInterrupt: 
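predict returns one vector of n_exercises values per timestep. To read off the prediction for the exercise the student actually attempts next, the output mask can select it. A minimal sketch, assuming the predictions and the mask have the same (n_students, n_timesteps, n_exercises) shape and the mask is one-hot (all zeros where there is no next exercise); the helper name is hypothetical:

```python
import numpy as np


def chosen_exercise_predictions(preds, mask):
    """Pick out, at each timestep, the prediction for the exercise attempted next.

    preds, mask: nested lists or arrays of shape (n_students, n_timesteps, n_exercises).
    Returns an (n_students, n_timesteps) array.
    """
    preds = np.asarray(preds, dtype=np.float64)
    mask = np.asarray(mask, dtype=np.float64)
    # The one-hot mask keeps only the chosen exercise's entry; summing drops the exercise axis.
    return (preds * mask).sum(axis=-1)


example = chosen_exercise_predictions(
    [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]],
    [[[0, 1, 0], [0, 0, 0]]],
)
```

Assuming predict returns nested lists as in the output below, it could be applied as chosen_exercise_predictions(predict(model, input_data_[:10, :10, :]), output_mask_[:10, :10, :]).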

In [43]:
predict(model, input_data_[:10, :10, :])


Out[43]:
[[[-0.007385450880974531,
   0.025550365447998047,
   0.011329921893775463,
   0.03704448416829109,
   0.039401981979608536,
   0.027618983760476112,
   0.017678581178188324,
   0.025421805679798126,
   0.01569029688835144,
   -0.02289711870253086],
  [-0.014513563364744186,
   0.0470220223069191,
   0.040145840495824814,
   0.09161123633384705,
   0.10280773788690567,
   0.06935103237628937,
   0.05226275697350502,
   0.041719403117895126,
   0.04879506677389145,
   -0.07392476499080658],
  [-0.03218340128660202,
   0.07135201245546341,
   0.0685245469212532,
   0.1436130404472351,
   0.19665291905403137,
   0.1438368409872055,
   0.1049000471830368,
   0.06255870312452316,
   0.1026017889380455,
   -0.11450488120317459],
  [-0.06921423971652985,
   0.08478009700775146,
   0.1043052151799202,
   0.19068250060081482,
   0.2767428159713745,
   0.23720687627792358,
   0.17555317282676697,
   0.10750668495893478,
   0.16081872582435608,
   -0.14396199584007263],
  [-0.11939936876296997,
   0.10138699412345886,
   0.13752755522727966,
   0.24032481014728546,
   0.3501714766025543,
   0.32600459456443787,
   0.26092734932899475,
   0.16472099721431732,
   0.22535595297813416,
   -0.17225711047649384],
  [-0.17461079359054565,
   0.1152421087026596,
   0.1707264631986618,
   0.2937884032726288,
   0.4026610255241394,
   0.4047226011753082,
   0.3470379114151001,
   0.2304689586162567,
   0.2942657470703125,
   -0.20128174126148224],
  [-0.23769868910312653,
   0.1092974990606308,
   0.2082236111164093,
   0.3569919466972351,
   0.4412563443183899,
   0.4894976019859314,
   0.422652930021286,
   0.32053667306900024,
   0.35850971937179565,
   -0.20820315182209015],
  [-0.2928321361541748,
   0.11726514250040054,
   0.22756768763065338,
   0.41967979073524475,
   0.4618801772594452,
   0.5404962301254272,
   0.47515639662742615,
   0.4064711928367615,
   0.41023027896881104,
   -0.218476802110672],
  [-0.34365835785865784,
   0.11928804963827133,
   0.2321077436208725,
   0.4715743064880371,
   0.4840993285179138,
   0.5689555406570435,
   0.5109775066375732,
   0.4563242495059967,
   0.46420493721961975,
   -0.2140037715435028],
  [-0.38566485047340393,
   0.1281179040670395,
   0.24280832707881927,
   0.5128179788589478,
   0.48435699939727783,
   0.6077094674110413,
   0.544141948223114,
   0.49809062480926514,
   0.5067498683929443,
   -0.2280215620994568]],
 [[0.001626665354706347,
   0.037479467689991,
   0.019430629909038544,
   0.03404674306511879,
   0.012370062991976738,
   0.03443584218621254,
   0.023976050317287445,
   0.009455711580812931,
   0.026551280170679092,
   -0.015678660944104195],
  [-0.014866768382489681,
   0.07031739503145218,
   0.037209831178188324,
   0.08887828141450882,
   0.06364879757165909,
   0.08181287348270416,
   0.05092545226216316,
   0.03311295807361603,
   0.051180221140384674,
   -0.04780920222401619],
  [-0.03127816319465637,
   0.09107029438018799,
   0.07287603616714478,
   0.15671797096729279,
   0.1430317461490631,
   0.14184020459651947,
   0.09511943906545639,
   0.050354160368442535,
   0.09149425476789474,
   -0.11039866507053375],
  [-0.060976751148700714,
   0.09801477193832397,
   0.11476847529411316,
   0.24379442632198334,
   0.2399071305990219,
   0.21377906203269958,
   0.16034163534641266,
   0.07652704417705536,
   0.14679443836212158,
   -0.17103984951972961],
  [-0.10324911773204803,
   0.11362592875957489,
   0.1525913029909134,
   0.31725165247917175,
   0.34183502197265625,
   0.3059535622596741,
   0.24026139080524445,
   0.1149774044752121,
   0.22379063069820404,
   -0.2214277982711792],
  [-0.1607619673013687,
   0.12293191254138947,
   0.1840815544128418,
   0.38961559534072876,
   0.40708380937576294,
   0.4095291197299957,
   0.3294735550880432,
   0.18316788971424103,
   0.2974473237991333,
   -0.26686352491378784],
  [-0.2251933366060257,
   0.129822239279747,
   0.21352751553058624,
   0.45145541429519653,
   0.44791433215141296,
   0.5020326375961304,
   0.4140762388706207,
   0.2700747847557068,
   0.3686193525791168,
   -0.3029167652130127],
  [-0.2910039722919464,
   0.1354573369026184,
   0.24534107744693756,
   0.4955673813819885,
   0.47051355242729187,
   0.5658157467842102,
   0.4853389859199524,
   0.3722583055496216,
   0.43880945444107056,
   -0.3212301731109619],
  [-0.3583813011646271,
   0.1330350786447525,
   0.279346227645874,
   0.5353736877441406,
   0.49029308557510376,
   0.6222624182701111,
   0.5359885692596436,
   0.47413772344589233,
   0.4909423291683197,
   -0.30825361609458923],
  [-0.4008702039718628,
   0.14655952155590057,
   0.29526978731155396,
   0.5669235587120056,
   0.4961823523044586,
   0.6476258635520935,
   0.5633453726768494,
   0.547645092010498,
   0.5222209692001343,
   -0.30059829354286194]],
 [[0.001626665354706347,
   0.037479467689991,
   0.019430629909038544,
   0.03404674306511879,
   0.012370062991976738,
   0.03443584218621254,
   0.023976050317287445,
   0.009455711580812931,
   0.026551280170679092,
   -0.015678660944104195],
  [-0.005862789694219828,
   0.08617755770683289,
   0.04532817006111145,
   0.0842113122344017,
   0.03539608418941498,
   0.08861597627401352,
   0.05721341073513031,
   0.0165015310049057,
   0.06334442645311356,
   -0.03913375735282898],
  [-0.031976114958524704,
   0.1187528744339943,
   0.07019568234682083,
   0.15418589115142822,
   0.10436864197254181,
   0.15436528623104095,
   0.09397280961275101,
   0.04188372194766998,
   0.0956481471657753,
   -0.08248429745435715],
  [-0.058634355664253235,
   0.1337626725435257,
   0.11179366707801819,
   0.23167745769023895,
   0.1990850567817688,
   0.22867818176746368,
   0.14849594235420227,
   0.06368935108184814,
   0.1434764862060547,
   -0.15573088824748993],
  [-0.09788542985916138,
   0.14540064334869385,
   0.1553751826286316,
   0.30297383666038513,
   0.3119584918022156,
   0.3226514458656311,
   0.22293721139431,
   0.0977112278342247,
   0.21700194478034973,
   -0.21283042430877686],
  [-0.1556023359298706,
   0.14597055315971375,
   0.19833093881607056,
   0.36874639987945557,
   0.38667625188827515,
   0.4158157408237457,
   0.3109106123447418,
   0.16846641898155212,
   0.2929220497608185,
   -0.2547641396522522],
  [-0.22068890929222107,
   0.15480518341064453,
   0.23436932265758514,
   0.4227451980113983,
   0.4372327923774719,
   0.4868674576282501,
   0.3958335220813751,
   0.2547925114631653,
   0.36704960465431213,
   -0.29151248931884766],
  [-0.29112106561660767,
   0.1480877697467804,
   0.27615466713905334,
   0.4806373417377472,
   0.46889474987983704,
   0.5629793405532837,
   0.47029396891593933,
   0.36612704396247864,
   0.436046302318573,
   -0.2990073561668396],
  [-0.3489680290222168,
   0.15215113759040833,
   0.29507625102996826,
   0.5293710827827454,
   0.47888946533203125,
   0.6004736423492432,
   0.5143836140632629,
   0.4624844491481781,
   0.4735715985298157,
   -0.3080124258995056],
  [-0.4017753005027771,
   0.16196505725383759,
   0.3105895519256592,
   0.564799427986145,
   0.4860957860946655,
   0.6261220574378967,
   0.5451611280441284,
   0.5394534468650818,
   0.5054194927215576,
   -0.312355101108551]],
 [[0.001626665354706347,
   0.037479467689991,
   0.019430629909038544,
   0.03404674306511879,
   0.012370062991976738,
   0.03443584218621254,
   0.023976050317287445,
   0.009455711580812931,
   0.026551280170679092,
   -0.015678660944104195],
  [-0.014866768382489681,
   0.07031739503145218,
   0.037209831178188324,
   0.08887828141450882,
   0.06364879757165909,
   0.08181287348270416,
   0.05092545226216316,
   0.03311295807361603,
   0.051180221140384674,
   -0.04780920222401619],
  [-0.03420348837971687,
   0.08495858311653137,
   0.0664568692445755,
   0.15616393089294434,
   0.13940481841564178,
   0.1283305585384369,
   0.08701346069574356,
   0.05158423259854317,
   0.08342994749546051,
   -0.08944535255432129],
  [-0.060907602310180664,
   0.09095901250839233,
   0.10269273072481155,
   0.232721745967865,
   0.23196962475776672,
   0.18219909071922302,
   0.13328084349632263,
   0.0728285014629364,
   0.12778891623020172,
   -0.14081640541553497],
  [-0.09361746907234192,
   0.0981183871626854,
   0.14935459196567535,
   0.3159990906715393,
   0.3259636163711548,
   0.25963741540908813,
   0.20035579800605774,
   0.10048488527536392,
   0.1935679167509079,
   -0.22298680245876312],
  [-0.1449628323316574,
   0.09888078272342682,
   0.19733908772468567,
   0.40551644563674927,
   0.4001621603965759,
   0.3491721451282501,
   0.2830325663089752,
   0.14919288456439972,
   0.2731110155582428,
   -0.2958981990814209],
  [-0.20487850904464722,
   0.10200666636228561,
   0.24171984195709229,
   0.4854142367839813,
   0.44695594906806946,
   0.4349038898944855,
   0.36876046657562256,
   0.21863026916980743,
   0.3542357385158539,
   -0.36073917150497437],
  [-0.2662322223186493,
   0.12092211842536926,
   0.28500714898109436,
   0.5353193879127502,
   0.4847032427787781,
   0.5113058090209961,
   0.44373786449432373,
   0.3023030459880829,
   0.43968331813812256,
   -0.4079033434391022],
  [-0.3269027769565582,
   0.13777515292167664,
   0.3185170888900757,
   0.5725964307785034,
   0.49580419063568115,
   0.5755017995834351,
   0.5088567733764648,
   0.40885812044143677,
   0.5052441358566284,
   -0.4334279000759125],
  [-0.3791249692440033,
   0.16812054812908173,
   0.3467302620410919,
   0.5862396955490112,
   0.5050480365753174,
   0.6146808862686157,
   0.546154797077179,
   0.49721837043762207,
   0.5499317646026611,
   -0.4526500999927521]],
 [[-0.007385450880974531,
   0.025550365447998047,
   0.011329921893775463,
   0.03704448416829109,
   0.039401981979608536,
   0.027618983760476112,
   0.017678581178188324,
   0.025421805679798126,
   0.01569029688835144,
   -0.02289711870253086],
  [-0.016434617340564728,
   0.04202163591980934,
   0.03387784957885742,
   0.09133366495370865,
   0.09891992807388306,
   0.05646282434463501,
   0.04472542181611061,
   0.0426873117685318,
   0.04018104821443558,
   -0.05347241833806038],
  [-0.02927660197019577,
   0.058802951127290726,
   0.07028387486934662,
   0.16001532971858978,
   0.18033723533153534,
   0.10849107801914215,
   0.09034308046102524,
   0.057852793484926224,
   0.0845632553100586,
   -0.11631344258785248],
  [-0.057247862219810486,
   0.08008474111557007,
   0.10546114295721054,
   0.2235705405473709,
   0.2839057445526123,
   0.19722367823123932,
   0.15593916177749634,
   0.0836041271686554,
   0.15329524874687195,
   -0.1659594476222992],
  [-0.10434599965810776,
   0.09520581364631653,
   0.1370006948709488,
   0.29061993956565857,
   0.3641490042209625,
   0.3098030984401703,
   0.24149750173091888,
   0.13650289177894592,
   0.22524765133857727,
   -0.21232809126377106],
  [-0.167626291513443,
   0.10245420038700104,
   0.17349250614643097,
   0.34856370091438293,
   0.41738319396972656,
   0.4118785858154297,
   0.3323492109775543,
   0.21695390343666077,
   0.30281487107276917,
   -0.2412136197090149],
  [-0.24291224777698517,
   0.10032673925161362,
   0.208566814661026,
   0.40867307782173157,
   0.4561443626880646,
   0.5027825236320496,
   0.41346457600593567,
   0.32292020320892334,
   0.37091270089149475,
   -0.24030572175979614],
  [-0.3041929602622986,
   0.11394774913787842,
   0.2258629947900772,
   0.46385255455970764,
   0.47372329235076904,
   0.5544580817222595,
   0.4730626344680786,
   0.4212020933628082,
   0.42455315589904785,
   -0.24243329465389252],
  [-0.3572836220264435,
   0.1321299821138382,
   0.22877410054206848,
   0.5150938034057617,
   0.489890992641449,
   0.5821978449821472,
   0.508150041103363,
   0.47752436995506287,
   0.4684116542339325,
   -0.25712013244628906],
  [-0.40285348892211914,
   0.1360579878091812,
   0.2240927666425705,
   0.5503410696983337,
   0.507268488407135,
   0.602586567401886,
   0.5363503098487854,
   0.5220065116882324,
   0.5163953304290771,
   -0.25548022985458374]],
 [[0.001626665354706347,
   0.037479467689991,
   0.019430629909038544,
   0.03404674306511879,
   0.012370062991976738,
   0.03443584218621254,
   0.023976050317287445,
   0.009455711580812931,
   0.026551280170679092,
   -0.015678660944104195],
  [-0.005862789694219828,
   0.08617755770683289,
   0.04532817006111145,
   0.0842113122344017,
   0.03539608418941498,
   0.08861597627401352,
   0.05721341073513031,
   0.0165015310049057,
   0.06334442645311356,
   -0.03913375735282898],
  [-0.023123623803257942,
   0.13784949481487274,
   0.07789395749568939,
   0.1472705751657486,
   0.07509306073188782,
   0.1604142040014267,
   0.09974164515733719,
   0.02444310672581196,
   0.10856691002845764,
   -0.07207928597927094],
  [-0.05961983650922775,
   0.16516649723052979,
   0.10900723189115524,
   0.22902508080005646,
   0.1622152030467987,
   0.2404518872499466,
   0.14683611690998077,
   0.05489376187324524,
   0.1488243192434311,
   -0.12612617015838623],
  [-0.09663838893175125,
   0.17179463803768158,
   0.15472890436649323,
   0.31104403734207153,
   0.2644428312778473,
   0.32147520780563354,
   0.21147173643112183,
   0.08512993156909943,
   0.20477068424224854,
   -0.20753706991672516],
  [-0.1476699709892273,
   0.16157129406929016,
   0.2033478319644928,
   0.4017667770385742,
   0.3551658093929291,
   0.40257886052131653,
   0.2912360429763794,
   0.13407102227210999,
   0.2749353349208832,
   -0.2794758379459381],
  [-0.20422206819057465,
   0.16578540205955505,
   0.25083914399147034,
   0.46711477637290955,
   0.43075132369995117,
   0.47971758246421814,
   0.3764747381210327,
   0.2016037255525589,
   0.36287787556648254,
   -0.33587899804115295],
  [-0.26813051104545593,
   0.1662045270204544,
   0.28966930508613586,
   0.5192147493362427,
   0.46377167105674744,
   0.5471265316009521,
   0.45559510588645935,
   0.3056398928165436,
   0.44026315212249756,
   -0.37334367632865906],
  [-0.3387893736362457,
   0.16086536645889282,
   0.3295435309410095,
   0.5598940849304199,
   0.4894082844257355,
   0.6116238236427307,
   0.5161793231964111,
   0.4252799153327942,
   0.5001445412635803,
   -0.37833917140960693],
  [-0.38768890500068665,
   0.1690857857465744,
   0.351093590259552,
   0.5892294645309448,
   0.49744564294815063,
   0.6454594731330872,
   0.5497063398361206,
   0.5187475085258484,
   0.5375459790229797,
   -0.3840266764163971]],
 [[-0.007385450880974531,
   0.025550365447998047,
   0.011329921893775463,
   0.03704448416829109,
   0.039401981979608536,
   0.027618983760476112,
   0.017678581178188324,
   0.025421805679798126,
   0.01569029688835144,
   -0.02289711870253086],
  [-0.016434617340564728,
   0.04202163591980934,
   0.03387784957885742,
   0.09133366495370865,
   0.09891992807388306,
   0.05646282434463501,
   0.04472542181611061,
   0.0426873117685318,
   0.04018104821443558,
   -0.05347241833806038],
  [-0.02927660197019577,
   0.058802951127290726,
   0.07028387486934662,
   0.16001532971858978,
   0.18033723533153534,
   0.10849107801914215,
   0.09034308046102524,
   0.057852793484926224,
   0.0845632553100586,
   -0.11631344258785248],
  [-0.058248504996299744,
   0.0690665915608406,
   0.11151129007339478,
   0.2467028647661209,
   0.2727585434913635,
   0.17967411875724792,
   0.15757513046264648,
   0.08372584730386734,
   0.1453004628419876,
   -0.1782950609922409],
  [-0.10167629271745682,
   0.08879223465919495,
   0.14719465374946594,
   0.31988584995269775,
   0.3662298917770386,
   0.2775735855102539,
   0.23900474607944489,
   0.12316052615642548,
   0.22674421966075897,
   -0.22890068590641022],
  [-0.1624550074338913,
   0.09984000772237778,
   0.18285717070102692,
   0.3849408030509949,
   0.42069122195243835,
   0.37895315885543823,
   0.3279584050178528,
   0.19788342714309692,
   0.3057538866996765,
   -0.26531723141670227],
  [-0.23845139145851135,
   0.10098498314619064,
   0.2175826132297516,
   0.4453183114528656,
   0.4605388939380646,
   0.47577205300331116,
   0.40886789560317993,
   0.3037096858024597,
   0.3751377761363983,
   -0.27267539501190186],
  [-0.30577918887138367,
   0.11194641143083572,
   0.23216237127780914,
   0.4968763589859009,
   0.474517285823822,
   0.530599057674408,
   0.4657686650753021,
   0.4047018885612488,
   0.42059555649757385,
   -0.27866390347480774],
  [-0.3670070171356201,
   0.12681278586387634,
   0.24325476586818695,
   0.5356999039649963,
   0.48277804255485535,
   0.5690165162086487,
   0.507861852645874,
   0.4898432195186615,
   0.45843878388404846,
   -0.2800208032131195],
  [-0.4167822003364563,
   0.14109325408935547,
   0.250868558883667,
   0.5658891201019287,
   0.4877112805843353,
   0.5969432592391968,
   0.5366802215576172,
   0.5557106733322144,
   0.4856134355068207,
   -0.2771853506565094]],
 [[-0.007385450880974531,
   0.025550365447998047,
   0.011329921893775463,
   0.03704448416829109,
   0.039401981979608536,
   0.027618983760476112,
   0.017678581178188324,
   0.025421805679798126,
   0.01569029688835144,
   -0.02289711870253086],
  [-0.016434617340564728,
   0.04202163591980934,
   0.03387784957885742,
   0.09133366495370865,
   0.09891992807388306,
   0.05646282434463501,
   0.04472542181611061,
   0.0426873117685318,
   0.04018104821443558,
   -0.05347241833806038],
  [-0.02927660197019577,
   0.058802951127290726,
   0.07028387486934662,
   0.16001532971858978,
   0.18033723533153534,
   0.10849107801914215,
   0.09034308046102524,
   0.057852793484926224,
   0.0845632553100586,
   -0.11631344258785248],
  [-0.058248504996299744,
   0.0690665915608406,
   0.11151129007339478,
   0.2467028647661209,
   0.2727585434913635,
   0.17967411875724792,
   0.15757513046264648,
   0.08372584730386734,
   0.1453004628419876,
   -0.1782950609922409],
  [-0.10211649537086487,
   0.07676362991333008,
   0.15105488896369934,
   0.34038570523262024,
   0.3536534905433655,
   0.26201990246772766,
   0.23862825334072113,
   0.12332990765571594,
   0.21722912788391113,
   -0.23994718492031097],
  [-0.1574753373861313,
   0.0961245521903038,
   0.18575678765773773,
   0.412643164396286,
   0.42478057742118835,
   0.35893985629081726,
   0.32350555062294006,
   0.17743146419525146,
   0.3039356470108032,
   -0.2907536029815674],
  [-0.22027722001075745,
   0.11041035503149033,
   0.2142597734928131,
   0.4794926047325134,
   0.45950832962989807,
   0.46049273014068604,
   0.40720677375793457,
   0.25972336530685425,
   0.3760809004306793,
   -0.3337879776954651],
  [-0.2838820815086365,
   0.12443418055772781,
   0.24316547811031342,
   0.528627872467041,
   0.48025697469711304,
   0.5452204942703247,
   0.47800135612487793,
   0.35232090950012207,
   0.43860000371932983,
   -0.3651861548423767],
  [-0.34057822823524475,
   0.13886594772338867,
   0.27253884077072144,
   0.5640837550163269,
   0.4918398857116699,
   0.607871949672699,
   0.5323774814605713,
   0.4408590495586395,
   0.48914486169815063,
   -0.3854181170463562],
  [-0.38752374053001404,
   0.15441176295280457,
   0.3025898039340973,
   0.5882860422134399,
   0.49890464544296265,
   0.6517040729522705,
   0.5718855857849121,
   0.5152390003204346,
   0.5280877351760864,
   -0.3967254161834717]],
 [[0.001626665354706347,
   0.037479467689991,
   0.019430629909038544,
   0.03404674306511879,
   0.012370062991976738,
   0.03443584218621254,
   0.023976050317287445,
   0.009455711580812931,
   0.026551280170679092,
   -0.015678660944104195],
  [-0.014866768382489681,
   0.07031739503145218,
   0.037209831178188324,
   0.08887828141450882,
   0.06364879757165909,
   0.08181287348270416,
   0.05092545226216316,
   0.03311295807361603,
   0.051180221140384674,
   -0.04780920222401619],
  [-0.03127816319465637,
   0.09107029438018799,
   0.07287603616714478,
   0.15671797096729279,
   0.1430317461490631,
   0.14184020459651947,
   0.09511943906545639,
   0.050354160368442535,
   0.09149425476789474,
   -0.11039866507053375],
  [-0.059480901807546616,
   0.11047469079494476,
   0.10961417853832245,
   0.22083410620689392,
   0.2510247528553009,
   0.23019792139530182,
   0.15908633172512054,
   0.07584737241268158,
   0.15513871610164642,
   -0.16032931208610535],
  [-0.10637495666742325,
   0.1215302050113678,
   0.14279113709926605,
   0.28813230991363525,
   0.3402049243450165,
   0.3377043902873993,
   0.2429751753807068,
   0.12786360085010529,
   0.22257700562477112,
   -0.2063126415014267],
  [-0.16847631335258484,
   0.12283267825841904,
   0.18033909797668457,
   0.34652021527290344,
   0.40150222182273865,
   0.43177521228790283,
   0.33234745264053345,
   0.2077411562204361,
   0.29760900139808655,
   -0.23532544076442719],
  [-0.23343591392040253,
   0.13489453494548798,
   0.21264009177684784,
   0.39830711483955383,
   0.4422001242637634,
   0.4987843334674835,
   0.41326168179512024,
   0.2947700619697571,
   0.3689412474632263,
   -0.26142585277557373],
  [-0.3018285632133484,
   0.1307247281074524,
   0.2514607012271881,
   0.4561883807182312,
   0.4691040813922882,
   0.5695794224739075,
   0.48165783286094666,
   0.3989560306072235,
   0.43338218331336975,
   -0.2604142129421234],
  [-0.3521101176738739,
   0.141281396150589,
   0.271355003118515,
   0.505827009677887,
   0.48101359605789185,
   0.6069961786270142,
   0.5242564678192139,
   0.48353955149650574,
   0.476498544216156,
   -0.26560845971107483],
  [-0.39889639616012573,
   0.14655019342899323,
   0.2772831618785858,
   0.5432493686676025,
   0.5000923871994019,
   0.6239126920700073,
   0.5520086288452148,
   0.5288835763931274,
   0.5202540159225464,
   -0.25520721077919006]],
 [[-0.007385450880974531,
   0.025550365447998047,
   0.011329921893775463,
   0.03704448416829109,
   0.039401981979608536,
   0.027618983760476112,
   0.017678581178188324,
   0.025421805679798126,
   0.01569029688835144,
   -0.02289711870253086],
  [-0.016434617340564728,
   0.04202163591980934,
   0.03387784957885742,
   0.09133366495370865,
   0.09891992807388306,
   0.05646282434463501,
   0.04472542181611061,
   0.0426873117685318,
   0.04018104821443558,
   -0.05347241833806038],
  [-0.02927660197019577,
   0.058802951127290726,
   0.07028387486934662,
   0.16001532971858978,
   0.18033723533153534,
   0.10849107801914215,
   0.09034308046102524,
   0.057852793484926224,
   0.0845632553100586,
   -0.11631344258785248],
  [-0.057247862219810486,
   0.08008474111557007,
   0.10546114295721054,
   0.2235705405473709,
   0.2839057445526123,
   0.19722367823123932,
   0.15593916177749634,
   0.0836041271686554,
   0.15329524874687195,
   -0.1659594476222992],
  [-0.10634910315275192,
   0.09204169362783432,
   0.14448414742946625,
   0.2828052043914795,
   0.3588876724243164,
   0.30093318223953247,
   0.23928526043891907,
   0.14087560772895813,
   0.22557410597801208,
   -0.20268292725086212],
  [-0.16713985800743103,
   0.10960502922534943,
   0.17833401262760162,
   0.3384683132171631,
   0.415674090385437,
   0.39195284247398376,
   0.3292734920978546,
   0.2125479131937027,
   0.299818217754364,
   -0.23643545806407928],
  [-0.23756016790866852,
   0.10867643356323242,
   0.2169070541858673,
   0.4016650319099426,
   0.45342424511909485,
   0.4848024845123291,
   0.4118563234806061,
   0.31293436884880066,
   0.3700428307056427,
   -0.24362049996852875],
  [-0.3006177842617035,
   0.11638326942920685,
   0.2345806509256363,
   0.4584605097770691,
   0.4678575098514557,
   0.536447286605835,
   0.468919575214386,
   0.4064263701438904,
   0.4150925278663635,
   -0.25325480103492737],
  [-0.3589994013309479,
   0.12805280089378357,
   0.2480306476354599,
   0.5039001703262329,
   0.47728273272514343,
   0.5728580355644226,
   0.5102958679199219,
   0.48629629611968994,
   0.4536206126213074,
   -0.2593274414539337],
  [-0.40764355659484863,
   0.13997741043567657,
   0.2571384608745575,
   0.54050213098526,
   0.4833797812461853,
   0.5996587872505188,
   0.5384678840637207,
   0.5493736267089844,
   0.4821021854877472,
   -0.2614176869392395]]]

In [52]:
import dynamics_model_class as dm

In [53]:
dmodel = dm.DynamicsModel(model_id="test_model", load_checkpoint=False)


Loading RNN dynamics model...
Directory path for tensorboard summaries: ../tensorboard_logs/test_model/
Checkpoint directory path: ../checkpoints/test_model/
Model loaded.
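The `fit` call inside `train` (visible in the traceback of the next cell) passes `[input_data, output_mask]` as inputs. Presumably the mask restricts the loss to the exercise that was actually attempted at each timestep, because the target only has a meaningful binary value there. Below is a minimal numpy sketch of such a masked binary cross-entropy. It assumes `output_mask` is one-hot over exercises at each timestep. It is an illustration only, not the loss implemented in `dynamics_model_class.py`, which is not shown here.

```python
from __future__ import print_function
import numpy as np

def masked_binary_cross_entropy(probs, mask, targets, eps=1e-7):
    """Mean binary cross-entropy over the entries selected by `mask`.

    probs, mask, targets: arrays of shape (n_students, n_timesteps, n_exercises).
    Assumption: mask is 1 at the exercise attempted at each timestep and 0
    elsewhere; targets holds 1 for correct, 0 for incorrect.
    """
    probs = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    per_entry = -(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))
    n_selected = mask.sum()
    if n_selected == 0:
        return 0.0
    # Unselected entries are multiplied by 0, so they do not affect the loss.
    return float((per_entry * mask).sum() / n_selected)

# Toy example: 1 student, 2 timesteps, 3 exercises.
probs = np.array([[[0.9, 0.5, 0.2],
                   [0.3, 0.6, 0.5]]])
mask = np.array([[[1, 0, 0],
                  [0, 0, 1]]], dtype=float)
targets = np.array([[[1, 0, 0],
                     [0, 0, 0]]], dtype=float)
loss = masked_binary_cross_entropy(probs, mask, targets)
print(loss)  # mean of -log(0.9) and -log(1 - 0.5)
```

Only two of the six predictions contribute here: the attempted exercise at each step. Changing any unmasked probability leaves the loss unchanged.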

In [56]:
dmodel.train(train_data)


Training Step: 1019  | total loss: 0.02487 | time: 0.567s
| Adam | epoch: 068 | loss: 0.02487 -- iter: 896/900

KeyboardInterruptTraceback (most recent call last)
<ipython-input-56-56596559dd51> in <module>()
----> 1 dmodel.train(train_data)

/Users/lisa1010/dev/smart-tutor/code/dynamics_model_class.pyc in train(self, train_data, load_checkpoint)
     86         date_time_string = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
     87         run_id = "{}".format(date_time_string)
---> 88         self.model.fit([input_data, output_mask], output_data, n_epoch=64, validation_set=0.1)
     89 
     90 

/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/models/dnn.pyc in fit(self, X_inputs, Y_targets, n_epoch, validation_set, show_metric, batch_size, shuffle, snapshot_epoch, snapshot_step, excl_trainops, validation_batch_size, run_id, callbacks)
    213                          excl_trainops=excl_trainops,
    214                          run_id=run_id,
--> 215                          callbacks=callbacks)
    216 
    217     def predict(self, X):

/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/helpers/trainer.pyc in fit(self, feed_dicts, n_epoch, val_feed_dicts, show_metric, snapshot_step, snapshot_epoch, shuffle_all, dprep_dict, daug_dict, excl_trainops, run_id, callbacks)
    331                                                        (bool(self.best_checkpoint_path) | snapshot_epoch),
    332                                                        snapshot_step,
--> 333                                                        show_metric)
    334 
    335                             # Update training state

/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/helpers/trainer.pyc in _train(self, training_step, snapshot_epoch, snapshot_step, show_metric)
    801             if show_metric and self.metric is not None:
    802                 eval_ops.append(self.metric)
--> 803             e = evaluate_flow(self.session, eval_ops, self.test_dflow)
    804             self.val_loss = e[0]
    805             if show_metric and self.metric is not None:

/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/helpers/trainer.pyc in evaluate_flow(session, ops_to_evaluate, dataflow)
    941             for i in range(len(r)):
    942                 res[i] += r[i] * current_batch_size
--> 943             feed_batch = dataflow.next()
    944         res = [r / dataflow.n_samples for r in res]
    945         return res

/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/data_flow.pyc in next(self, timeout)
    127         """
    128         self.data_status.update()
--> 129         return self.feed_dict_queue.get(timeout=timeout)
    130 
    131     def start(self, reset_status=True):

/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/Queue.pyc in get(self, block, timeout)
    166             elif timeout is None:
    167                 while not self._qsize():
--> 168                     self.not_empty.wait()
    169             elif timeout < 0:
    170                 raise ValueError("'timeout' must be a non-negative number")

/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.pyc in wait(self, timeout)
    338         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    339             if timeout is None:
--> 340                 waiter.acquire()
    341                 if __debug__:
    342                     self._note("%s.wait(): got it", self)

KeyboardInterrupt: 
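The traceback records a manual interrupt (`KeyboardInterrupt`) raised while `evaluate_flow` waited for validation batches at the end of an epoch. It is not a model error. The progress line is internally consistent, as the sketch below checks. "iter: 896/900" implies 900 training sequences. That fits roughly 1000 sequences with 10% held out by `validation_set=0.1`. The batch size of 64 is an assumption, namely tflearn's `regression` default, because the notebook does not show the value used.

```python
from __future__ import print_function
import math

n_train = 900    # training samples, read off the "iter: 896/900" counter
batch_size = 64  # assumption: tflearn's regression layer default batch_size
steps_per_epoch = int(math.ceil(n_train / float(batch_size)))  # ceil(14.06) = 15

def global_step(epoch, batch_in_epoch):
    """Global training step for a 1-based epoch and 1-based batch index."""
    return (epoch - 1) * steps_per_epoch + batch_in_epoch

# 896 samples seen means 14 full batches of 64 have run in epoch 68.
batches_done = 896 // batch_size
print(steps_per_epoch, global_step(68, batches_done))  # 15 1019
```

The result, 67 × 15 + 14 = 1019, matches "Training Step: 1019" in the log.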

In [57]:
preds = dmodel.predict(input_data_[:10, :10, :])


Out[57]:
[[[0.49599671363830566,
   0.505977988243103,
   0.49741360545158386,
   0.5000333189964294,
   0.498515248298645,
   0.5003623366355896,
   0.5020152926445007,
   0.4968193471431732,
   0.5021724104881287,
   0.49825751781463623],
  [0.48964592814445496,
   0.5083518624305725,
   0.49532902240753174,
   0.5052512884140015,
   0.49550819396972656,
   0.5047191381454468,
   0.49943220615386963,
   0.49179941415786743,
   0.5016341805458069,
   0.4988301992416382],
  [0.48247480392456055,
   0.5100704431533813,
   0.4926689863204956,
   0.5104095339775085,
   0.4944813549518585,
   0.5111148357391357,
   0.49826157093048096,
   0.4853592813014984,
   0.4960009455680847,
   0.49970728158950806],
  [0.4703198969364166,
   0.5135535001754761,
   0.4915870130062103,
   0.5127080678939819,
   0.49413204193115234,
   0.5221576690673828,
   0.502672016620636,
   0.47870734333992004,
   0.4900960922241211,
   0.5019947290420532],
  [0.4537630081176758,
   0.520048201084137,
   0.49032503366470337,
   0.5139161944389343,
   0.4934532642364502,
   0.5316247344017029,
   0.5167728662490845,
   0.4716108441352844,
   0.4958451986312866,
   0.4966897964477539],
  [0.43617871403694153,
   0.5289715528488159,
   0.48939430713653564,
   0.5126253962516785,
   0.4925546646118164,
   0.5344418883323669,
   0.5401607155799866,
   0.4698314368724823,
   0.5158414244651794,
   0.4840362071990967],
  [0.42516717314720154,
   0.5388312339782715,
   0.48898425698280334,
   0.508516788482666,
   0.49308663606643677,
   0.5259639620780945,
   0.5653541088104248,
   0.4776161313056946,
   0.5385245084762573,
   0.46874111890792847],
  [0.4183195233345032,
   0.5485021471977234,
   0.4903010129928589,
   0.5027713179588318,
   0.48928093910217285,
   0.5091527700424194,
   0.5853952765464783,
   0.4925719201564789,
   0.557790994644165,
   0.456940621137619],
  [0.41411229968070984,
   0.559482753276825,
   0.49123120307922363,
   0.497923344373703,
   0.4830032289028168,
   0.48612239956855774,
   0.6010857820510864,
   0.5239127278327942,
   0.5803111791610718,
   0.44925323128700256],
  [0.4249594509601593,
   0.569526195526123,
   0.4904220998287201,
   0.4949359595775604,
   0.46843722462654114,
   0.45716696977615356,
   0.6109771728515625,
   0.5648123025894165,
   0.5905900597572327,
   0.4446971118450165]],
 [[0.4935073256492615,
   0.5066367387771606,
   0.5006406307220459,
   0.4931355118751526,
   0.5011196732521057,
   0.5027900338172913,
   0.5059390664100647,
   0.5037103891372681,
   0.5045592188835144,
   0.49498361349105835],
  [0.4888042211532593,
   0.5124666690826416,
   0.49886104464530945,
   0.48768845200538635,
   0.502082884311676,
   0.5015733242034912,
   0.5139681100845337,
   0.5072578191757202,
   0.5110902190208435,
   0.4893607795238495],
  [0.4837706387042999,
   0.5132308006286621,
   0.4966180622577667,
   0.4896356761455536,
   0.5010759830474854,
   0.5032262802124023,
   0.5173490047454834,
   0.5093849897384644,
   0.5153816342353821,
   0.48739975690841675],
  [0.47697913646698,
   0.5159948468208313,
   0.4962058365345001,
   0.5016492605209351,
   0.5036230087280273,
   0.5028076767921448,
   0.5183938145637512,
   0.5096410512924194,
   0.5118613839149475,
   0.48518186807632446],
  [0.47057032585144043,
   0.5163217782974243,
   0.49360454082489014,
   0.5125990509986877,
   0.5019816160202026,
   0.5014887452125549,
   0.5191658139228821,
   0.5111626386642456,
   0.5062094330787659,
   0.483379989862442],
  [0.4688839614391327,
   0.5220568776130676,
   0.49056243896484375,
   0.5243667960166931,
   0.49032846093177795,
   0.4946741461753845,
   0.5149816870689392,
   0.5242074728012085,
   0.5099191665649414,
   0.47194865345954895],
  [0.47019073367118835,
   0.5310572981834412,
   0.4862220287322998,
   0.5377153158187866,
   0.47338351607322693,
   0.4833642244338989,
   0.4996504783630371,
   0.5456037521362305,
   0.5182302594184875,
   0.456155389547348],
  [0.469103068113327,
   0.5361096262931824,
   0.47985541820526123,
   0.5500144362449646,
   0.4661874771118164,
   0.475056529045105,
   0.4756297767162323,
   0.5621740221977234,
   0.5216403603553772,
   0.4482935965061188],
  [0.4695298969745636,
   0.5345526933670044,
   0.4739120900630951,
   0.5572773814201355,
   0.4666861593723297,
   0.4686569571495056,
   0.4480878710746765,
   0.5711529850959778,
   0.522301435470581,
   0.44190213084220886],
  [0.468523770570755,
   0.525304913520813,
   0.4741703271865845,
   0.5627244710922241,
   0.4680138826370239,
   0.4681204557418823,
   0.42252233624458313,
   0.5644826889038086,
   0.5175504088401794,
   0.4397382140159607]],
 [[0.4935073256492615,
   0.5066367387771606,
   0.5006406307220459,
   0.4931355118751526,
   0.5011196732521057,
   0.5027900338172913,
   0.5059390664100647,
   0.5037103891372681,
   0.5045592188835144,
   0.49498361349105835],
  [0.4864151179790497,
   0.5134059190750122,
   0.5013522505760193,
   0.4798593819141388,
   0.5054095983505249,
   0.503986120223999,
   0.5183268785476685,
   0.5136097073554993,
   0.5136396884918213,
   0.4860149323940277],
  [0.48251864314079285,
   0.5177899599075317,
   0.49937862157821655,
   0.4698704779148102,
   0.5096774697303772,
   0.4998442530632019,
   0.5328044295310974,
   0.5246942043304443,
   0.5241942405700684,
   0.4782741963863373],
  [0.48027995228767395,
   0.5149013996124268,
   0.49656760692596436,
   0.46929213404655457,
   0.511268675327301,
   0.4979315996170044,
   0.5424811840057373,
   0.5347463488578796,
   0.5325235724449158,
   0.4752316176891327],
  [0.48054632544517517,
   0.5101013779640198,
   0.49272894859313965,
   0.4788307845592499,
   0.5128154158592224,
   0.495523601770401,
   0.5500748753547668,
   0.5437270998954773,
   0.5377021431922913,
   0.473631352186203],
  [0.47937169671058655,
   0.5055421590805054,
   0.48871538043022156,
   0.49055591225624084,
   0.5119155049324036,
   0.497022807598114,
   0.5583524703979492,
   0.5503260493278503,
   0.5420855283737183,
   0.47418126463890076],
  [0.47849005460739136,
   0.5042028427124023,
   0.48473697900772095,
   0.5034549832344055,
   0.5114812254905701,
   0.4919143319129944,
   0.5679455995559692,
   0.5568810105323792,
   0.552671492099762,
   0.4691372215747833],
  [0.47729960083961487,
   0.5018625855445862,
   0.4813368618488312,
   0.5148546695709229,
   0.5120936632156372,
   0.482450932264328,
   0.5753395557403564,
   0.5665517449378967,
   0.5611532926559448,
   0.4625680148601532],
  [0.4716723561286926,
   0.4921557605266571,
   0.4792180359363556,
   0.5260759592056274,
   0.5021523833274841,
   0.4689690172672272,
   0.5777547359466553,
   0.5858073830604553,
   0.5734472870826721,
   0.4575400948524475],
  [0.46615177392959595,
   0.47498705983161926,
   0.47839879989624023,
   0.5374900102615356,
   0.48372694849967957,
   0.4545191824436188,
   0.5819700360298157,
   0.609631359577179,
   0.585426926612854,
   0.4527510702610016]],
 [[0.4935073256492615,
   0.5066367387771606,
   0.5006406307220459,
   0.4931355118751526,
   0.5011196732521057,
   0.5027900338172913,
   0.5059390664100647,
   0.5037103891372681,
   0.5045592188835144,
   0.49498361349105835],
  [0.4888042211532593,
   0.5124666690826416,
   0.49886104464530945,
   0.48768845200538635,
   0.502082884311676,
   0.5015733242034912,
   0.5139681100845337,
   0.5072578191757202,
   0.5110902190208435,
   0.4893607795238495],
  [0.4929240345954895,
   0.516448974609375,
   0.4938470125198364,
   0.49369972944259644,
   0.5044259428977966,
   0.4983064830303192,
   0.5112334489822388,
   0.5068015456199646,
   0.5119608640670776,
   0.4841019809246063],
  [0.5007560849189758,
   0.5182065367698669,
   0.4872228801250458,
   0.5079019665718079,
   0.5078688263893127,
   0.49604564905166626,
   0.4969274699687958,
   0.5014709830284119,
   0.5072579383850098,
   0.48117199540138245],
  [0.5020344257354736,
   0.5155802965164185,
   0.48553648591041565,
   0.521838903427124,
   0.5078266859054565,
   0.5006723403930664,
   0.4766537845134735,
   0.49318230152130127,
   0.4999012053012848,
   0.4852582514286041],
  [0.4984493851661682,
   0.5138178467750549,
   0.4931369125843048,
   0.5304259061813354,
   0.5104289650917053,
   0.5083906054496765,
   0.4594757854938507,
   0.48229333758354187,
   0.483430951833725,
   0.49318283796310425],
  [0.48839783668518066,
   0.5124683380126953,
   0.5079970359802246,
   0.5351042151451111,
   0.5117958188056946,
   0.5207926630973816,
   0.4490090012550354,
   0.4710412621498108,
   0.4594714045524597,
   0.5025239586830139],
  [0.4691604673862457,
   0.5104724168777466,
   0.5210692286491394,
   0.5382067561149597,
   0.5093942284584045,
   0.5367593765258789,
   0.44400638341903687,
   0.4590891897678375,
   0.4302017092704773,
   0.513044536113739],
  [0.43615126609802246,
   0.5113045573234558,
   0.5329358577728271,
   0.5375150442123413,
   0.5057876706123352,
   0.558469295501709,
   0.44797810912132263,
   0.44826582074165344,
   0.4029180407524109,
   0.5220268368721008],
  [0.39566367864608765,
   0.5144246220588684,
   0.5379319787025452,
   0.5362714529037476,
   0.5033884644508362,
   0.5835130214691162,
   0.4657069146633148,
   0.4361811578273773,
   0.38456860184669495,
   0.5213557481765747]],
 [[0.49599671363830566,
   0.505977988243103,
   0.49741360545158386,
   0.5000333189964294,
   0.498515248298645,
   0.5003623366355896,
   0.5020152926445007,
   0.4968193471431732,
   0.5021724104881287,
   0.49825751781463623],
  [0.4982739984989166,
   0.5114309191703796,
   0.4927102029323578,
   0.5087923407554626,
   0.4982874393463135,
   0.5002113580703735,
   0.4950408637523651,
   0.49080711603164673,
   0.4985194206237793,
   0.495974600315094],
  [0.495180606842041,
   0.5128109455108643,
   0.49134019017219543,
   0.5186378359794617,
   0.4964984357357025,
   0.505379319190979,
   0.4830396771430969,
   0.4827333390712738,
   0.4921058714389801,
   0.49838364124298096],
  [0.488839328289032,
   0.5130492448806763,
   0.4918130338191986,
   0.5242824554443359,
   0.4961698055267334,
   0.5140914916992188,
   0.47307127714157104,
   0.47317343950271606,
   0.4796505272388458,
   0.5027008652687073],
  [0.48134538531303406,
   0.5185043215751648,
   0.4974914789199829,
   0.5281264185905457,
   0.49063757061958313,
   0.5196006894111633,
   0.46588170528411865,
   0.4708947241306305,
   0.4721260964870453,
   0.49853694438934326],
  [0.46692249178886414,
   0.5220759510993958,
   0.5048929452896118,
   0.5288379788398743,
   0.4898432195186615,
   0.5277711153030396,
   0.4598439633846283,
   0.46979770064353943,
   0.4605463743209839,
   0.4996195435523987],
  [0.44905009865760803,
   0.52361661195755,
   0.5140159130096436,
   0.5291934609413147,
   0.4894247353076935,
   0.5352762937545776,
   0.45834168791770935,
   0.4670118987560272,
   0.44874122738838196,
   0.496447890996933],
  [0.42665520310401917,
   0.5251380205154419,
   0.5297182202339172,
   0.5279824733734131,
   0.48798325657844543,
   0.5414242148399353,
   0.4594918489456177,
   0.46260207891464233,
   0.43408021330833435,
   0.49125513434410095],
  [0.4031164348125458,
   0.5257107615470886,
   0.5459632873535156,
   0.5304495096206665,
   0.4891797602176666,
   0.5490188598632812,
   0.4588281512260437,
   0.44747695326805115,
   0.4023078680038452,
   0.4877915680408478],
  [0.3729502260684967,
   0.5255358219146729,
   0.5609372854232788,
   0.5321497917175293,
   0.48992788791656494,
   0.5642527341842651,
   0.46293866634368896,
   0.4378302991390228,
   0.3744458854198456,
   0.4855373203754425]],
 [[0.4935073256492615,
   0.5066367387771606,
   0.5006406307220459,
   0.4931355118751526,
   0.5011196732521057,
   0.5027900338172913,
   0.5059390664100647,
   0.5037103891372681,
   0.5045592188835144,
   0.49498361349105835],
  [0.4864151179790497,
   0.5134059190750122,
   0.5013522505760193,
   0.4798593819141388,
   0.5054095983505249,
   0.503986120223999,
   0.5183268785476685,
   0.5136097073554993,
   0.5136396884918213,
   0.4860149323940277],
  [0.4804767370223999,
   0.5187111496925354,
   0.5010325312614441,
   0.4613550305366516,
   0.5138576626777649,
   0.5021450519561768,
   0.5373328924179077,
   0.5301012396812439,
   0.5266373753547668,
   0.4751938283443451],
  [0.4786783754825592,
   0.5192803740501404,
   0.49868643283843994,
   0.4485275149345398,
   0.5217024087905884,
   0.4942969083786011,
   0.5571663975715637,
   0.548382043838501,
   0.5394854545593262,
   0.4674091041088104],
  [0.4800170063972473,
   0.5085130929946899,
   0.4956192672252655,
   0.44785237312316895,
   0.5255739688873291,
   0.4887203872203827,
   0.5711811184883118,
   0.5651856660842896,
   0.5490503311157227,
   0.46478331089019775],
  [0.48153260350227356,
   0.49942296743392944,
   0.4930483400821686,
   0.4832015037536621,
   0.5232486724853516,
   0.47565048933029175,
   0.5736194252967834,
   0.5832151174545288,
   0.5526918172836304,
   0.46285244822502136],
  [0.4864833652973175,
   0.48481810092926025,
   0.48882997035980225,
   0.5194024443626404,
   0.513222873210907,
   0.46280092000961304,
   0.5727747678756714,
   0.5997238159179688,
   0.5543747544288635,
   0.45990189909935],
  [0.4895656704902649,
   0.4684159457683563,
   0.4830937385559082,
   0.5518429279327393,
   0.4951283931732178,
   0.4554482400417328,
   0.5680369734764099,
   0.6143181324005127,
   0.5561201572418213,
   0.4577709138393402],
  [0.4911474585533142,
   0.4511832296848297,
   0.4759911000728607,
   0.5732237100601196,
   0.48121845722198486,
   0.4484676420688629,
   0.5565111637115479,
   0.6265982389450073,
   0.5558149814605713,
   0.4528338611125946],
  [0.4923345148563385,
   0.4308885931968689,
   0.4694708287715912,
   0.5909002423286438,
   0.46782487630844116,
   0.4447915256023407,
   0.5333371758460999,
   0.6277155876159668,
   0.553862452507019,
   0.4495563507080078]],
 [[0.49599671363830566,
   0.505977988243103,
   0.49741360545158386,
   0.5000333189964294,
   0.498515248298645,
   0.5003623366355896,
   0.5020152926445007,
   0.4968193471431732,
   0.5021724104881287,
   0.49825751781463623],
  [0.4982739984989166,
   0.5114309191703796,
   0.4927102029323578,
   0.5087923407554626,
   0.4982874393463135,
   0.5002113580703735,
   0.4950408637523651,
   0.49080711603164673,
   0.4985194206237793,
   0.495974600315094],
  [0.495180606842041,
   0.5128109455108643,
   0.49134019017219543,
   0.5186378359794617,
   0.4964984357357025,
   0.505379319190979,
   0.4830396771430969,
   0.4827333390712738,
   0.4921058714389801,
   0.49838364124298096],
  [0.48699524998664856,
   0.514747679233551,
   0.4969709813594818,
   0.5244477987289429,
   0.5002292394638062,
   0.5135923027992249,
   0.47346407175064087,
   0.4732531011104584,
   0.47692635655403137,
   0.5020098090171814],
  [0.4732399582862854,
   0.5144741535186768,
   0.501699686050415,
   0.5289419889450073,
   0.5003151893615723,
   0.5248336791992188,
   0.46726882457733154,
   0.46334242820739746,
   0.4561832547187805,
   0.5071861743927002],
  [0.44964221119880676,
   0.5158843994140625,
   0.5079138875007629,
   0.5293943881988525,
   0.4989471137523651,
   0.5415902733802795,
   0.4693746268749237,
   0.45465245842933655,
   0.43583324551582336,
   0.5127857327461243],
  [0.4221138060092926,
   0.5176545977592468,
   0.5147653222084045,
   0.5295987129211426,
   0.49693167209625244,
   0.5577367544174194,
   0.47876542806625366,
   0.4452420175075531,
   0.41869810223579407,
   0.5082810521125793],
  [0.38811028003692627,
   0.5214492082595825,
   0.5184506177902222,
   0.5253399014472961,
   0.49174442887306213,
   0.5693498253822327,
   0.5011019706726074,
   0.45353057980537415,
   0.4164471924304962,
   0.4968683123588562],
  [0.36234354972839355,
   0.5278410911560059,
   0.5153379440307617,
   0.5202239751815796,
   0.48276323080062866,
   0.572359561920166,
   0.5301076769828796,
   0.4818229079246521,
   0.4330885708332062,
   0.4801550805568695],
  [0.35057127475738525,
   0.5377871990203857,
   0.5082552433013916,
   0.5131154656410217,
   0.4674236476421356,
   0.5599708557128906,
   0.5598285794258118,
   0.533362865447998,
   0.4811905324459076,
   0.4621477723121643]],
 [[0.49599671363830566,
   0.505977988243103,
   0.49741360545158386,
   0.5000333189964294,
   0.498515248298645,
   0.5003623366355896,
   0.5020152926445007,
   0.4968193471431732,
   0.5021724104881287,
   0.49825751781463623],
  [0.4982739984989166,
   0.5114309191703796,
   0.4927102029323578,
   0.5087923407554626,
   0.4982874393463135,
   0.5002113580703735,
   0.4950408637523651,
   0.49080711603164673,
   0.4985194206237793,
   0.495974600315094],
  [0.495180606842041,
   0.5128109455108643,
   0.49134019017219543,
   0.5186378359794617,
   0.4964984357357025,
   0.505379319190979,
   0.4830396771430969,
   0.4827333390712738,
   0.4921058714389801,
   0.49838364124298096],
  [0.48699524998664856,
   0.514747679233551,
   0.4969709813594818,
   0.5244477987289429,
   0.5002292394638062,
   0.5135923027992249,
   0.47346407175064087,
   0.4732531011104584,
   0.47692635655403137,
   0.5020098090171814],
  [0.47155246138572693,
   0.5163626074790955,
   0.5073354840278625,
   0.5284625291824341,
   0.5036620497703552,
   0.5252302885055542,
   0.4685995578765869,
   0.4637703597545624,
   0.4550075829029083,
   0.505318284034729],
  [0.44791990518569946,
   0.5157428979873657,
   0.5143195390701294,
   0.5320842862129211,
   0.5027382373809814,
   0.5384891629219055,
   0.46804317831993103,
   0.4545551836490631,
   0.4287174344062805,
   0.509219229221344],
  [0.4274927079677582,
   0.5203395485877991,
   0.5206096172332764,
   0.5363749265670776,
   0.4941402077674866,
   0.5440685153007507,
   0.47415691614151,
   0.456539124250412,
   0.4133474826812744,
   0.49973049759864807],
  [0.41368141770362854,
   0.5285932421684265,
   0.5237830281257629,
   0.5407341122627258,
   0.47680577635765076,
   0.5372069478034973,
   0.4789503216743469,
   0.47505199909210205,
   0.40986666083335876,
   0.4800032675266266],
  [0.40893372893333435,
   0.540021538734436,
   0.5230647921562195,
   0.5466728806495667,
   0.453422874212265,
   0.5180574059486389,
   0.4767017364501953,
   0.5112937092781067,
   0.4170326292514801,
   0.4565376937389374],
  [0.412627249956131,
   0.5536556839942932,
   0.5195404291152954,
   0.5547642111778259,
   0.4317971467971802,
   0.49344614148139954,
   0.4625207781791687,
   0.5525162220001221,
   0.4331713318824768,
   0.43483832478523254]],
 [[0.4935073256492615,
   0.5066367387771606,
   0.5006406307220459,
   0.4931355118751526,
   0.5011196732521057,
   0.5027900338172913,
   0.5059390664100647,
   0.5037103891372681,
   0.5045592188835144,
   0.49498361349105835],
  [0.4888042211532593,
   0.5124666690826416,
   0.49886104464530945,
   0.48768845200538635,
   0.502082884311676,
   0.5015733242034912,
   0.5139681100845337,
   0.5072578191757202,
   0.5110902190208435,
   0.4893607795238495],
  [0.4837706387042999,
   0.5132308006286621,
   0.4966180622577667,
   0.4896356761455536,
   0.5010759830474854,
   0.5032262802124023,
   0.5173490047454834,
   0.5093849897384644,
   0.5153816342353821,
   0.48739975690841675],
  [0.4796925485134125,
   0.5129684805870056,
   0.49323591589927673,
   0.495869517326355,
   0.501282811164856,
   0.50560462474823,
   0.5209391713142395,
   0.5099912881851196,
   0.5158118009567261,
   0.4864618182182312],
  [0.4783409535884857,
   0.5181713700294495,
   0.49011996388435364,
   0.503265917301178,
   0.49712297320365906,
   0.5029231905937195,
   0.5176870822906494,
   0.5175051093101501,
   0.5209943652153015,
   0.4768555462360382],
  [0.47424039244651794,
   0.5219233632087708,
   0.4868907630443573,
   0.5113959312438965,
   0.49551692605018616,
   0.5024139285087585,
   0.5127443671226501,
   0.5239893794059753,
   0.5231109261512756,
   0.4731998145580292],
  [0.4707843065261841,
   0.5274010300636292,
   0.4835318326950073,
   0.5203237533569336,
   0.4962611496448517,
   0.4970822036266327,
   0.5096275210380554,
   0.527242124080658,
   0.5329503417015076,
   0.46647271513938904],
  [0.4672967195510864,
   0.5302364826202393,
   0.4804839491844177,
   0.525964617729187,
   0.49858424067497253,
   0.4899025857448578,
   0.5045422911643982,
   0.5310878157615662,
   0.5402598977088928,
   0.4598078727722168],
  [0.46306389570236206,
   0.5290700197219849,
   0.4814821779727936,
   0.529658317565918,
   0.49663886427879333,
   0.48454758524894714,
   0.49519726634025574,
   0.5276616215705872,
   0.5427936315536499,
   0.45509234070777893],
  [0.4558480978012085,
   0.53023362159729,
   0.4821963608264923,
   0.5352409482002258,
   0.4964573383331299,
   0.4807736575603485,
   0.4828425943851471,
   0.5261974334716797,
   0.5431912541389465,
   0.4511813819408417]],
 [[0.49599671363830566,
   0.505977988243103,
   0.49741360545158386,
   0.5000333189964294,
   0.498515248298645,
   0.5003623366355896,
   0.5020152926445007,
   0.4968193471431732,
   0.5021724104881287,
   0.49825751781463623],
  [0.4982739984989166,
   0.5114309191703796,
   0.4927102029323578,
   0.5087923407554626,
   0.4982874393463135,
   0.5002113580703735,
   0.4950408637523651,
   0.49080711603164673,
   0.4985194206237793,
   0.495974600315094],
  [0.495180606842041,
   0.5128109455108643,
   0.49134019017219543,
   0.5186378359794617,
   0.4964984357357025,
   0.505379319190979,
   0.4830396771430969,
   0.4827333390712738,
   0.4921058714389801,
   0.49838364124298096],
  [0.488839328289032,
   0.5130492448806763,
   0.4918130338191986,
   0.5242824554443359,
   0.4961698055267334,
   0.5140914916992188,
   0.47307127714157104,
   0.47317343950271606,
   0.4796505272388458,
   0.5027008652687073],
  [0.47480571269989014,
   0.5146147012710571,
   0.49649596214294434,
   0.5256413817405701,
   0.4959128499031067,
   0.5280333161354065,
   0.4700377285480499,
   0.46416711807250977,
   0.4655117690563202,
   0.5094239115715027],
  [0.4534894526004791,
   0.5187195539474487,
   0.5011202096939087,
   0.5259437561035156,
   0.49559882283210754,
   0.5434014797210693,
   0.4788511395454407,
   0.45399218797683716,
   0.4594345688819885,
   0.509539783000946],
  [0.4300541579723358,
   0.5228154063224792,
   0.5059940814971924,
   0.5240729451179504,
   0.4956946074962616,
   0.5538676977157593,
   0.493369996547699,
   0.4469212293624878,
   0.45663511753082275,
   0.4996611475944519],
  [0.4003496468067169,
   0.5280584096908569,
   0.5086494088172913,
   0.5175948143005371,
   0.4922581613063812,
   0.5592440366744995,
   0.5186375975608826,
   0.4586865305900574,
   0.46907535195350647,
   0.4862392842769623],
  [0.3774715065956116,
   0.535921037197113,
   0.5061038136482239,
   0.5092968940734863,
   0.4850473701953888,
   0.5549939870834351,
   0.5482835173606873,
   0.4893377721309662,
   0.5029461979866028,
   0.47050967812538147],
  [0.3672855794429779,
   0.5468208193778992,
   0.501282811164856,
   0.49827826023101807,
   0.4723666310310364,
   0.535592257976532,
   0.5762588977813721,
   0.5382131934165955,
   0.5526970624923706,
   0.45590513944625854]]]
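The output above has shape (10 students, 10 timesteps, 10 exercises). Each inner vector gives the predicted probability of a correct answer for every possible next exercise. To compare predictions with what students actually did, keep only the entry for the exercise chosen at each step. The sketch below does this with numpy advanced indexing. The integer `actions` array is hypothetical, because this part of the notebook does not show how actions are encoded in `input_data_`.

```python
from __future__ import print_function
import numpy as np

def chosen_exercise_probs(preds, actions):
    """Pick out p(correct) for the exercise chosen at each timestep.

    preds:   array-like, shape (n_students, n_timesteps, n_exercises)
    actions: int array-like, shape (n_students, n_timesteps), exercise indices
             (hypothetical encoding, used here for illustration)
    returns: array, shape (n_students, n_timesteps)
    """
    preds = np.asarray(preds)
    actions = np.asarray(actions)
    n_students, n_timesteps = actions.shape
    rows = np.arange(n_students)[:, None]   # broadcast across timesteps
    cols = np.arange(n_timesteps)[None, :]  # broadcast across students
    return preds[rows, cols, actions]

# Toy example: 1 student, 2 timesteps, 3 exercises.
toy_preds = [[[0.1, 0.7, 0.4],
              [0.2, 0.3, 0.9]]]
toy_actions = [[1, 2]]
chosen = chosen_exercise_probs(toy_preds, toy_actions)
print(chosen)  # p(correct) for exercise 1 at t=0 (0.7) and exercise 2 at t=1 (0.9)
```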

In [74]:
generator_model = dm.DynamicsModel(model_id="test_model", timesteps=1, load_checkpoint=False)


Loading RNN dynamics model...
Directory path for tensorboard summaries: ../tensorboard_logs/test_model/
Checkpoint directory path: ../checkpoints/test_model/
Model loaded.

In [70]:
preds = generator_model.predict(input_data_[:10, :1, :])

In [72]:
print preds[0]


[[0.5023449063301086, 0.4977347254753113, 0.4989054501056671, 0.4967387318611145, 0.5014218091964722, 0.5007814168930054, 0.49788951873779297, 0.5013884902000427, 0.5018420815467834, 0.49715477228164673]]
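With `timesteps=1`, `preds[0]` is a single row of per-exercise probabilities for one student. A model used as a generator can turn this row into a simulated outcome. The approach is to take the probability for the chosen exercise and draw a Bernoulli sample from it. The sketch below shows this idea. `sample_response` is a hypothetical helper, not part of `dynamics_model_class`, and the probabilities in the usage example are illustrative.

```python
from __future__ import print_function
import numpy as np

def sample_response(step_preds, exercise, rng):
    """Sample whether a simulated student solves `exercise`.

    step_preds: shape (1, n_exercises), one student's single-timestep
                prediction (in the same layout as preds[0] above).
    exercise:   int index of the chosen exercise (the action).
    rng:        numpy RandomState, so draws are reproducible.
    Returns 1 for correct, 0 for incorrect.
    """
    p_correct = float(np.asarray(step_preds)[0, exercise])
    # rng.rand() is uniform on [0, 1), so this is 1 with probability p_correct.
    return int(rng.rand() < p_correct)

rng = np.random.RandomState(0)
step_preds = [[0.9, 0.2, 0.5]]  # illustrative probabilities for 3 exercises
draws = [sample_response(step_preds, 0, rng) for _ in range(1000)]
print(np.mean(draws))  # fraction of simulated correct answers, near 0.9
```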
