Sketchbook for experimenting with RNN models of the system dynamics. Once tested, the code in this notebook will be refactored into functions, classes, and Python modules.
The model captures the dynamics of the system, p(next observation | history of observations, action), where:
The history of observations consists of the exercises a student has attempted and whether the student solved each of them.
The action is the next exercise chosen.
The next observation is whether the student solves the chosen exercise.
We use an RNN to model these dynamics. The input data represents the history of observations and has shape (n_students, n_timesteps, observation_vec_size).
The output represents the probability of solving the next exercise and has shape (n_students, n_timesteps, n_exercises).
So at each timestep, the model makes a prediction for every possible action.
Entry i of the output vector is the predicted probability that the student solves exercise i if it is the one chosen next.
The target output contains only binary values.
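To make the output layout concrete: if the per-timestep mask over exercises (the notebook later calls it `output_mask`) is a one-hot vector marking the exercise chosen next, which is an assumption based on its name and the description above, then the prediction for the action actually taken can be read off by masking and summing over the exercise axis. A minimal NumPy sketch with made-up numbers; `chosen_action_predictions` is a hypothetical helper, not part of `dataset_utils`:

```python
from __future__ import print_function
import numpy as np


def chosen_action_predictions(preds, output_mask):
    """Return the prediction for the chosen exercise at each timestep.

    preds:       (n_students, n_timesteps, n_exercises) model outputs
    output_mask: same shape, assumed one-hot over the exercise chosen next
    Returns an array of shape (n_students, n_timesteps).
    """
    # The one-hot mask zeroes every entry except the chosen exercise;
    # summing over the exercise axis then extracts that single entry.
    return (preds * output_mask).sum(axis=-1)


# Toy example: 1 student, 2 timesteps, 3 exercises (hypothetical numbers)
preds = np.array([[[0.2, 0.7, 0.5],
                   [0.9, 0.1, 0.4]]])
chosen = np.array([[1, 0]])        # index of the exercise chosen at each step
output_mask = np.eye(3)[chosen]    # one-hot mask, shape (1, 2, 3)
print(chosen_action_predictions(preds, output_mask))
```

The same masking idea is what `net * output_mask` does inside the graph built below, except that there the masked tensor is kept at full width rather than summed.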
In [16]:
import sys
print sys.executable
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
/Users/lisa1010/tf_venv/bin/python
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
In [2]:
import sonnet as snt
import tensorflow as tf
import tflearn
import numpy as np
In [3]:
import dataset_utils
Load the toy dataset
In [5]:
data = dataset_utils.load_data(filename="../synthetic_data/toy.pickle")
input_data_, output_mask_, target_data_ = dataset_utils.preprocess_data_for_rnn(data)
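`dataset_utils.preprocess_data_for_rnn` is not shown here. The shapes printed below (20 input dimensions, 10 output dimensions) are consistent with, but do not confirm, an encoding with one input slot per (exercise, correct/incorrect) pair. A hypothetical sketch of such an encoding for a single student; `encode_sequence` and the exact layout are illustrative assumptions, not the actual `dataset_utils` code:

```python
from __future__ import print_function
import numpy as np


def encode_sequence(exercises, correct, n_exercises):
    """Hypothetical encoding of one student's history.

    exercises: exercise indices attempted, in order
    correct:   0/1 outcomes for those attempts
    Returns (inputs, output_mask, targets), one row per transition t -> t+1.
    """
    n_steps = len(exercises) - 1
    inputs = np.zeros((n_steps, 2 * n_exercises), dtype=np.float32)
    output_mask = np.zeros((n_steps, n_exercises), dtype=np.float32)
    targets = np.zeros((n_steps, n_exercises), dtype=np.float32)
    for t in range(n_steps):
        # Input: slot e for "attempted e, incorrect", slot e + n_exercises for "correct"
        inputs[t, exercises[t] + n_exercises * correct[t]] = 1.0
        # Mask and target both refer to the exercise attempted at step t + 1
        nxt = exercises[t + 1]
        output_mask[t, nxt] = 1.0
        targets[t, nxt] = correct[t + 1]
    return inputs, output_mask, targets


inputs, mask, targets = encode_sequence([3, 1, 3], [0, 1, 1], n_exercises=10)
print(inputs.shape, mask.shape, targets.shape)
```

Under this layout the target is zero wherever the mask is zero, which matters for how the masked loss behaves during training.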
In [6]:
tflearn.init_graph()
Out[6]:
gpu_options {
}
allow_soft_placement: true
In [9]:
n_hidden = 64
n_samples, n_timesteps, n_inputdim = input_data_.shape
_,_,n_outputdim = target_data_.shape
In [10]:
print n_timesteps
50
In [12]:
print n_inputdim
20
In [13]:
print n_outputdim
10
In [15]:
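graph_to_use = tf.Graph()
with graph_to_use.as_default():
    # Placeholders for the observation history and the per-exercise output mask
    net = tflearn.input_data([None, n_timesteps, n_inputdim], dtype=tf.float32, name='input_data')
    output_mask = tflearn.input_data([None, n_timesteps, n_outputdim], dtype=tf.float32, name='output_mask')
    # With return_seq=True each LSTM returns a list of n_timesteps tensors of shape (batch, units)
    net = tflearn.lstm(net, n_hidden, return_seq=True, name="lstm_1")
    # One output unit per exercise; tflearn's default tanh activation bounds outputs to (-1, 1)
    net = tflearn.lstm(net, n_outputdim, return_seq=True, name="lstm_2")
    # Stack the per-timestep outputs into shape (batch, n_timesteps, n_outputdim)
    net = tf.stack(net, axis=1)
    preds = net
    # Zero out predictions wherever output_mask is 0 (presumably every exercise except the one chosen next)
    net = net * output_mask
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                             loss='mean_square')
    model = tflearn.DNN(net, tensorboard_verbose=0)
    model.fit([input_data_, output_mask_], target_data_, validation_set=0.1)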
---------------------------------
Run id: 2PTSBL
Log directory: /tmp/tflearn_logs/
WARNING:tensorflow:Error encountered when serializing layer_tensor/lstm_1.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'
WARNING:tensorflow:Error encountered when serializing layer_tensor/lstm_2.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'
---------------------------------
Training samples: 4
Validation samples: 1
--
Training Step: 1 | time: 2.116s
| Adam | epoch: 001 | loss: 0.00000 | val_loss: 0.04748 -- iter: 4/4
--
Training Step: 2 | total loss: 0.04031 | time: 1.067s
| Adam | epoch: 002 | loss: 0.04031 | val_loss: 0.04669 -- iter: 4/4
--
Training Step: 3 | total loss: 0.04343 | time: 1.065s
| Adam | epoch: 003 | loss: 0.04343 | val_loss: 0.04589 -- iter: 4/4
--
Training Step: 4 | total loss: 0.04341 | time: 1.061s
| Adam | epoch: 004 | loss: 0.04341 | val_loss: 0.04506 -- iter: 4/4
--
Training Step: 5 | total loss: 0.04293 | time: 1.060s
| Adam | epoch: 005 | loss: 0.04293 | val_loss: 0.04419 -- iter: 4/4
--
Training Step: 6 | total loss: 0.04231 | time: 1.056s
| Adam | epoch: 006 | loss: 0.04231 | val_loss: 0.04326 -- iter: 4/4
--
Training Step: 7 | total loss: 0.04168 | time: 1.061s
| Adam | epoch: 007 | loss: 0.04168 | val_loss: 0.04223 -- iter: 4/4
--
Training Step: 8 | total loss: 0.04099 | time: 1.055s
| Adam | epoch: 008 | loss: 0.04099 | val_loss: 0.04111 -- iter: 4/4
--
Training Step: 9 | total loss: 0.04024 | time: 1.059s
| Adam | epoch: 009 | loss: 0.04024 | val_loss: 0.03985 -- iter: 4/4
--
Training Step: 10 | total loss: 0.03943 | time: 1.053s
| Adam | epoch: 010 | loss: 0.03943 | val_loss: 0.03846 -- iter: 4/4
--
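tflearn's `mean_square` loss averages the squared error over every entry of the (batch, timestep, exercise) tensor, including entries zeroed by the mask. Assuming the target is also zero wherever the mask is zero, masked-out entries add no error but still count in the denominator, so with a one-hot mask the reported loss equals the per-attempt squared error divided by the number of exercises. A NumPy sketch under those assumptions; `masked_mse` is a hypothetical helper:

```python
from __future__ import print_function
import numpy as np


def masked_mse(preds, targets, mask):
    """Return (mean over all entries, mean over masked entries only)."""
    sq_err = np.square(preds * mask - targets)
    # Averaging over every entry mimics a plain mean-square loss on the masked output
    mean_all = sq_err.mean()
    # Averaging only where mask == 1 gives the squared error per actual attempt
    mean_masked = (sq_err * mask).sum() / mask.sum()
    return mean_all, mean_masked


# 1 student, 2 timesteps, 3 exercises; targets are zero off the mask (assumption)
preds = np.array([[[0.2, 0.7, 0.5],
                   [0.9, 0.1, 0.4]]])
mask = np.eye(3)[np.array([[1, 0]])]
targets = np.array([[[0.0, 1.0, 0.0],
                     [0.0, 0.0, 0.0]]])
mean_all, mean_masked = masked_mse(preds, targets, mask)
print(mean_all, mean_masked)
```

Here the two values differ by exactly the number of exercises (3), which is worth keeping in mind when reading the small loss values in the logs above.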
In [72]:
tf.get_collection(tf.GraphKeys.INPUTS)
Out[72]:
[<tf.Tensor 'output_mask:0' shape=(?, 50, 10) dtype=int32>,
<tf.Tensor 'InputData/X:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_1:0' shape=(?, 50, 10) dtype=int32>,
<tf.Tensor 'InputData_1/X:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_2:0' shape=(?, 50, 10) dtype=int32>,
<tf.Tensor 'InputData_2/X:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_3:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'InputData_3/X:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'input_data:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_4:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'input_data_1:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_5:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'input_data_2:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_6:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'input_data_3:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_7:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'input_data_4:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_8:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'output_mask_9:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'input_data_5/X:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'input_data_6/X:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_10/X:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'output_mask_9:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'input_data_7/X:0' shape=(?, 50, 20) dtype=float32>,
<tf.Tensor 'output_mask_11/X:0' shape=(?, 50, 10) dtype=float32>,
<tf.Tensor 'output_mask_9:0' shape=(?, 50, 10) dtype=float32>]
In [17]:
data = dataset_utils.load_data(filename="../synthetic_data/1000stud_100seq_expert.pickle")
input_data_, output_mask_, target_data_ = dataset_utils.preprocess_data_for_rnn(data)
In [20]:
n_samples, n_timesteps, n_inputdim = input_data_.shape
_,_,n_outputdim = target_data_.shape
print n_samples
print n_timesteps
print n_inputdim
print n_outputdim
1000
100
20
10
In [22]:
graph_to_use = tf.Graph()
with graph_to_use.as_default():
    net = tflearn.input_data([None, n_timesteps, n_inputdim], dtype=tf.float32, name='input_data')
    output_mask = tflearn.input_data([None, n_timesteps, n_outputdim], dtype=tf.float32, name='output_mask')
    net = tflearn.lstm(net, n_hidden, return_seq=True, name="lstm_1")
    net = tflearn.lstm(net, n_outputdim, return_seq=True, name="lstm_2")
    net = tf.stack(net, axis=1)
    preds = net
    net = net * output_mask
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                             loss='mean_square')
    model = tflearn.DNN(net, tensorboard_verbose=2)
    model.fit([input_data_, output_mask_], target_data_, n_epoch=32, validation_set=0.1)
Training Step: 479 | total loss: 0.02480 | time: 2.358s
| Adam | epoch: 032 | loss: 0.02480 -- iter: 896/900
Training Step: 480 | total loss: 0.02479 | time: 3.469s
| Adam | epoch: 032 | loss: 0.02479 | val_loss: 0.02476 -- iter: 900/900
--
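Both runs report splits consistent with holding out 10% of the students and truncating the training count: the 5 toy students give 4 training / 1 validation, and the 1000 students give 900 training (so 100 validation). A hypothetical sketch of such a split; `split_validation` is an illustrative name, and which samples tflearn actually holds out, and how it rounds, may differ:

```python
from __future__ import print_function
import numpy as np


def split_validation(arrays, val_fraction=0.1):
    """Hold out the last val_fraction of samples (hypothetical split rule).

    arrays: list of numpy arrays sharing the same first (student) dimension
    """
    n_samples = arrays[0].shape[0]
    # int() truncates, so 5 samples -> 4 train / 1 validation
    n_train = int(n_samples * (1 - val_fraction))
    train = [a[:n_train] for a in arrays]
    val = [a[n_train:] for a in arrays]
    return train, val


x = np.zeros((1000, 100, 20), dtype=np.float32)
mask = np.zeros((1000, 100, 10), dtype=np.float32)
(train_x, train_mask), (val_x, val_mask) = split_validation([x, mask])
print(train_x.shape[0], val_x.shape[0])
```

Splitting all arrays with the same index keeps each student's inputs, mask, and targets aligned.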
In [23]:
from dynamics_model import load_model, train, predict
In [39]:
model = load_model(model_id="test_model", load_checkpoint=False, is_training=True)
Loading RNN dynamics model...
Directory path for tensorboard summaries: ../tensorboard_logs/test_model/
Checkpoint directory path: ../checkpoints/test_model/
Model loaded.
In [40]:
train_data = (input_data_[:,:10,:], output_mask_[:,:10,:], target_data_[:,:10,:])
In [41]:
train(model, train_data)
Training Step: 44 | total loss: 0.03538 | time: 0.681s
| Adam | epoch: 003 | loss: 0.03538 -- iter: 896/900
Training Step: 45 | total loss: 0.03482 | time: 1.726s
| Adam | epoch: 003 | loss: 0.03482 | val_loss: 0.03224 -- iter: 900/900
--
KeyboardInterrupt Traceback (most recent call last)
<ipython-input-41-175a1436c9bb> in <module>()
----> 1 train(model, train_data)
/Users/lisa1010/dev/smart-tutor/code/dynamics_model.pyc in train(model, train_data, load_checkpoint)
95 date_time_string = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
96 run_id = "{}".format(date_time_string)
---> 97 model.fit([input_data, output_mask], output_data, n_epoch=64, validation_set=0.1)
98
99
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/models/dnn.pyc in fit(self, X_inputs, Y_targets, n_epoch, validation_set, show_metric, batch_size, shuffle, snapshot_epoch, snapshot_step, excl_trainops, validation_batch_size, run_id, callbacks)
213 excl_trainops=excl_trainops,
214 run_id=run_id,
--> 215 callbacks=callbacks)
216
217 def predict(self, X):
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/helpers/trainer.pyc in fit(self, feed_dicts, n_epoch, val_feed_dicts, show_metric, snapshot_step, snapshot_epoch, shuffle_all, dprep_dict, daug_dict, excl_trainops, run_id, callbacks)
344
345 # Epoch end
--> 346 caller.on_epoch_end(self.training_state)
347
348 finally:
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/callbacks.pyc in on_epoch_end(self, training_state)
78 def on_epoch_end(self, training_state):
79 for callback in self.callbacks:
---> 80 callback.on_epoch_end(training_state)
81
82 def on_train_end(self, training_state):
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/callbacks.pyc in on_epoch_end(self, training_state)
273 def on_epoch_end(self, training_state):
274 if self.snapshot_epoch:
--> 275 self.save(training_state.step)
276
277 def on_batch_begin(self, training_state):
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/callbacks.pyc in save(self, training_step)
302 def save(self, training_step=0):
303 if self.snapshot_path:
--> 304 self.save_func(self.snapshot_path, training_step)
305
306 def save_best(self, val_accuracy):
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tflearn/helpers/trainer.pyc in save(self, model_file, global_step)
371 if not os.path.isabs(model_file):
372 model_file = os.path.abspath(os.path.join(os.getcwd(), model_file))
--> 373 self.saver.save(self.session, model_file, global_step=global_step)
374 utils.fix_saver(obj_lists)
375
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc in save(self, sess, save_path, global_step, latest_filename, meta_graph_suffix, write_meta_graph, write_state)
1373 checkpoint_file, meta_graph_suffix=meta_graph_suffix)
1374 with sess.graph.as_default():
-> 1375 self.export_meta_graph(meta_graph_filename)
1376
1377 if self._is_empty:
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc in export_meta_graph(self, filename, collection_list, as_text, export_scope, clear_devices)
1401 return export_meta_graph(
1402 filename=filename,
-> 1403 graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
1404 saver_def=self.saver_def,
1405 collection_list=collection_list,
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in as_graph_def(self, from_version, add_shapes)
2187 ValueError: If the `graph_def` would be too large.
2188 """
-> 2189 result, _ = self._as_graph_def(from_version, add_shapes)
2190 return result
2191
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in _as_graph_def(self, from_version, add_shapes)
2146 if op.outputs and add_shapes:
2147 assert "_output_shapes" not in graph.node[-1].attr
-> 2148 graph.node[-1].attr["_output_shapes"].list.shape.extend([
2149 output.get_shape().as_proto() for output in op.outputs])
2150 bytesize += op.node_def.ByteSize()
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/google/protobuf/internal/python_message.pyc in getter(self)
703 if field_value is None:
704 # Construct a new object to represent this field.
--> 705 field_value = field._default_constructor(self)
706
707 # Atomically check if another thread has preempted us and, if not, swap
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/google/protobuf/internal/python_message.pyc in MakeSubMessageDefault(message)
424 result._SetListener(
425 _OneofListener(message, field)
--> 426 if field.containing_oneof is not None
427 else message._listener_for_children)
428 return result
/Users/lisa1010/tf_venv/lib/python2.7/site-packages/google/protobuf/internal/python_message.pyc in __init__(self, parent_message, field)
1407 field: The descriptor of the field being set in the parent message.
1408 """
-> 1409 super(_OneofListener, self).__init__(parent_message)
1410 self._field = field
1411
KeyboardInterrupt:
In [43]:
predict(model, input_data_[:10, :10, :])
Out[43]:
[[[-0.007385450880974531,
0.025550365447998047,
0.011329921893775463,
0.03704448416829109,
0.039401981979608536,
0.027618983760476112,
0.017678581178188324,
0.025421805679798126,
0.01569029688835144,
-0.02289711870253086],
[-0.014513563364744186,
0.0470220223069191,
0.040145840495824814,
0.09161123633384705,
0.10280773788690567,
0.06935103237628937,
0.05226275697350502,
0.041719403117895126,
0.04879506677389145,
-0.07392476499080658],
[-0.03218340128660202,
0.07135201245546341,
0.0685245469212532,
0.1436130404472351,
0.19665291905403137,
0.1438368409872055,
0.1049000471830368,
0.06255870312452316,
0.1026017889380455,
-0.11450488120317459],
[-0.06921423971652985,
0.08478009700775146,
0.1043052151799202,
0.19068250060081482,
0.2767428159713745,
0.23720687627792358,
0.17555317282676697,
0.10750668495893478,
0.16081872582435608,
-0.14396199584007263],
[-0.11939936876296997,
0.10138699412345886,
0.13752755522727966,
0.24032481014728546,
0.3501714766025543,
0.32600459456443787,
0.26092734932899475,
0.16472099721431732,
0.22535595297813416,
-0.17225711047649384],
[-0.17461079359054565,
0.1152421087026596,
0.1707264631986618,
0.2937884032726288,
0.4026610255241394,
0.4047226011753082,
0.3470379114151001,
0.2304689586162567,
0.2942657470703125,
-0.20128174126148224],
[-0.23769868910312653,
0.1092974990606308,
0.2082236111164093,
0.3569919466972351,
0.4412563443183899,
0.4894976019859314,
0.422652930021286,
0.32053667306900024,
0.35850971937179565,
-0.20820315182209015],
[-0.2928321361541748,
0.11726514250040054,
0.22756768763065338,
0.41967979073524475,
0.4618801772594452,
0.5404962301254272,
0.47515639662742615,
0.4064711928367615,
0.41023027896881104,
-0.218476802110672],
[-0.34365835785865784,
0.11928804963827133,
0.2321077436208725,
0.4715743064880371,
0.4840993285179138,
0.5689555406570435,
0.5109775066375732,
0.4563242495059967,
0.46420493721961975,
-0.2140037715435028],
[-0.38566485047340393,
0.1281179040670395,
0.24280832707881927,
0.5128179788589478,
0.48435699939727783,
0.6077094674110413,
0.544141948223114,
0.49809062480926514,
0.5067498683929443,
-0.2280215620994568]],
[[0.001626665354706347,
0.037479467689991,
0.019430629909038544,
0.03404674306511879,
0.012370062991976738,
0.03443584218621254,
0.023976050317287445,
0.009455711580812931,
0.026551280170679092,
-0.015678660944104195],
[-0.014866768382489681,
0.07031739503145218,
0.037209831178188324,
0.08887828141450882,
0.06364879757165909,
0.08181287348270416,
0.05092545226216316,
0.03311295807361603,
0.051180221140384674,
-0.04780920222401619],
[-0.03127816319465637,
0.09107029438018799,
0.07287603616714478,
0.15671797096729279,
0.1430317461490631,
0.14184020459651947,
0.09511943906545639,
0.050354160368442535,
0.09149425476789474,
-0.11039866507053375],
[-0.060976751148700714,
0.09801477193832397,
0.11476847529411316,
0.24379442632198334,
0.2399071305990219,
0.21377906203269958,
0.16034163534641266,
0.07652704417705536,
0.14679443836212158,
-0.17103984951972961],
[-0.10324911773204803,
0.11362592875957489,
0.1525913029909134,
0.31725165247917175,
0.34183502197265625,
0.3059535622596741,
0.24026139080524445,
0.1149774044752121,
0.22379063069820404,
-0.2214277982711792],
[-0.1607619673013687,
0.12293191254138947,
0.1840815544128418,
0.38961559534072876,
0.40708380937576294,
0.4095291197299957,
0.3294735550880432,
0.18316788971424103,
0.2974473237991333,
-0.26686352491378784],
[-0.2251933366060257,
0.129822239279747,
0.21352751553058624,
0.45145541429519653,
0.44791433215141296,
0.5020326375961304,
0.4140762388706207,
0.2700747847557068,
0.3686193525791168,
-0.3029167652130127],
[-0.2910039722919464,
0.1354573369026184,
0.24534107744693756,
0.4955673813819885,
0.47051355242729187,
0.5658157467842102,
0.4853389859199524,
0.3722583055496216,
0.43880945444107056,
-0.3212301731109619],
[-0.3583813011646271,
0.1330350786447525,
0.279346227645874,
0.5353736877441406,
0.49029308557510376,
0.6222624182701111,
0.5359885692596436,
0.47413772344589233,
0.4909423291683197,
-0.30825361609458923],
[-0.4008702039718628,
0.14655952155590057,
0.29526978731155396,
0.5669235587120056,
0.4961823523044586,
0.6476258635520935,
0.5633453726768494,
0.547645092010498,
0.5222209692001343,
-0.30059829354286194]],
[[0.001626665354706347,
0.037479467689991,
0.019430629909038544,
0.03404674306511879,
0.012370062991976738,
0.03443584218621254,
0.023976050317287445,
0.009455711580812931,
0.026551280170679092,
-0.015678660944104195],
[-0.005862789694219828,
0.08617755770683289,
0.04532817006111145,
0.0842113122344017,
0.03539608418941498,
0.08861597627401352,
0.05721341073513031,
0.0165015310049057,
0.06334442645311356,
-0.03913375735282898],
[-0.031976114958524704,
0.1187528744339943,
0.07019568234682083,
0.15418589115142822,
0.10436864197254181,
0.15436528623104095,
0.09397280961275101,
0.04188372194766998,
0.0956481471657753,
-0.08248429745435715],
[-0.058634355664253235,
0.1337626725435257,
0.11179366707801819,
0.23167745769023895,
0.1990850567817688,
0.22867818176746368,
0.14849594235420227,
0.06368935108184814,
0.1434764862060547,
-0.15573088824748993],
[-0.09788542985916138,
0.14540064334869385,
0.1553751826286316,
0.30297383666038513,
0.3119584918022156,
0.3226514458656311,
0.22293721139431,
0.0977112278342247,
0.21700194478034973,
-0.21283042430877686],
[-0.1556023359298706,
0.14597055315971375,
0.19833093881607056,
0.36874639987945557,
0.38667625188827515,
0.4158157408237457,
0.3109106123447418,
0.16846641898155212,
0.2929220497608185,
-0.2547641396522522],
[-0.22068890929222107,
0.15480518341064453,
0.23436932265758514,
0.4227451980113983,
0.4372327923774719,
0.4868674576282501,
0.3958335220813751,
0.2547925114631653,
0.36704960465431213,
-0.29151248931884766],
[-0.29112106561660767,
0.1480877697467804,
0.27615466713905334,
0.4806373417377472,
0.46889474987983704,
0.5629793405532837,
0.47029396891593933,
0.36612704396247864,
0.436046302318573,
-0.2990073561668396],
[-0.3489680290222168,
0.15215113759040833,
0.29507625102996826,
0.5293710827827454,
0.47888946533203125,
0.6004736423492432,
0.5143836140632629,
0.4624844491481781,
0.4735715985298157,
-0.3080124258995056],
[-0.4017753005027771,
0.16196505725383759,
0.3105895519256592,
0.564799427986145,
0.4860957860946655,
0.6261220574378967,
0.5451611280441284,
0.5394534468650818,
0.5054194927215576,
-0.312355101108551]],
[[0.001626665354706347,
0.037479467689991,
0.019430629909038544,
0.03404674306511879,
0.012370062991976738,
0.03443584218621254,
0.023976050317287445,
0.009455711580812931,
0.026551280170679092,
-0.015678660944104195],
[-0.014866768382489681,
0.07031739503145218,
0.037209831178188324,
0.08887828141450882,
0.06364879757165909,
0.08181287348270416,
0.05092545226216316,
0.03311295807361603,
0.051180221140384674,
-0.04780920222401619],
[-0.03420348837971687,
0.08495858311653137,
0.0664568692445755,
0.15616393089294434,
0.13940481841564178,
0.1283305585384369,
0.08701346069574356,
0.05158423259854317,
0.08342994749546051,
-0.08944535255432129],
[-0.060907602310180664,
0.09095901250839233,
0.10269273072481155,
0.232721745967865,
0.23196962475776672,
0.18219909071922302,
0.13328084349632263,
0.0728285014629364,
0.12778891623020172,
-0.14081640541553497],
[-0.09361746907234192,
0.0981183871626854,
0.14935459196567535,
0.3159990906715393,
0.3259636163711548,
0.25963741540908813,
0.20035579800605774,
0.10048488527536392,
0.1935679167509079,
-0.22298680245876312],
[-0.1449628323316574,
0.09888078272342682,
0.19733908772468567,
0.40551644563674927,
0.4001621603965759,
0.3491721451282501,
0.2830325663089752,
0.14919288456439972,
0.2731110155582428,
-0.2958981990814209],
[-0.20487850904464722,
0.10200666636228561,
0.24171984195709229,
0.4854142367839813,
0.44695594906806946,
0.4349038898944855,
0.36876046657562256,
0.21863026916980743,
0.3542357385158539,
-0.36073917150497437],
[-0.2662322223186493,
0.12092211842536926,
0.28500714898109436,
0.5353193879127502,
0.4847032427787781,
0.5113058090209961,
0.44373786449432373,
0.3023030459880829,
0.43968331813812256,
-0.4079033434391022],
[-0.3269027769565582,
0.13777515292167664,
0.3185170888900757,
0.5725964307785034,
0.49580419063568115,
0.5755017995834351,
0.5088567733764648,
0.40885812044143677,
0.5052441358566284,
-0.4334279000759125],
[-0.3791249692440033,
0.16812054812908173,
0.3467302620410919,
0.5862396955490112,
0.5050480365753174,
0.6146808862686157,
0.546154797077179,
0.49721837043762207,
0.5499317646026611,
-0.4526500999927521]],
[[-0.007385450880974531,
0.025550365447998047,
0.011329921893775463,
0.03704448416829109,
0.039401981979608536,
0.027618983760476112,
0.017678581178188324,
0.025421805679798126,
0.01569029688835144,
-0.02289711870253086],
[-0.016434617340564728,
0.04202163591980934,
0.03387784957885742,
0.09133366495370865,
0.09891992807388306,
0.05646282434463501,
0.04472542181611061,
0.0426873117685318,
0.04018104821443558,
-0.05347241833806038],
[-0.02927660197019577,
0.058802951127290726,
0.07028387486934662,
0.16001532971858978,
0.18033723533153534,
0.10849107801914215,
0.09034308046102524,
0.057852793484926224,
0.0845632553100586,
-0.11631344258785248],
[-0.057247862219810486,
0.08008474111557007,
0.10546114295721054,
0.2235705405473709,
0.2839057445526123,
0.19722367823123932,
0.15593916177749634,
0.0836041271686554,
0.15329524874687195,
-0.1659594476222992],
[-0.10434599965810776,
0.09520581364631653,
0.1370006948709488,
0.29061993956565857,
0.3641490042209625,
0.3098030984401703,
0.24149750173091888,
0.13650289177894592,
0.22524765133857727,
-0.21232809126377106],
[-0.167626291513443,
0.10245420038700104,
0.17349250614643097,
0.34856370091438293,
0.41738319396972656,
0.4118785858154297,
0.3323492109775543,
0.21695390343666077,
0.30281487107276917,
-0.2412136197090149],
[-0.24291224777698517,
0.10032673925161362,
0.208566814661026,
0.40867307782173157,
0.4561443626880646,
0.5027825236320496,
0.41346457600593567,
0.32292020320892334,
0.37091270089149475,
-0.24030572175979614],
[-0.3041929602622986,
0.11394774913787842,
0.2258629947900772,
0.46385255455970764,
0.47372329235076904,
0.5544580817222595,
0.4730626344680786,
0.4212020933628082,
0.42455315589904785,
-0.24243329465389252],
[-0.3572836220264435,
0.1321299821138382,
0.22877410054206848,
0.5150938034057617,
0.489890992641449,
0.5821978449821472,
0.508150041103363,
0.47752436995506287,
0.4684116542339325,
-0.25712013244628906],
[-0.40285348892211914,
0.1360579878091812,
0.2240927666425705,
0.5503410696983337,
0.507268488407135,
0.602586567401886,
0.5363503098487854,
0.5220065116882324,
0.5163953304290771,
-0.25548022985458374]],
[[0.001626665354706347,
0.037479467689991,
0.019430629909038544,
0.03404674306511879,
0.012370062991976738,
0.03443584218621254,
0.023976050317287445,
0.009455711580812931,
0.026551280170679092,
-0.015678660944104195],
[-0.005862789694219828,
0.08617755770683289,
0.04532817006111145,
0.0842113122344017,
0.03539608418941498,
0.08861597627401352,
0.05721341073513031,
0.0165015310049057,
0.06334442645311356,
-0.03913375735282898],
[-0.023123623803257942,
0.13784949481487274,
0.07789395749568939,
0.1472705751657486,
0.07509306073188782,
0.1604142040014267,
0.09974164515733719,
0.02444310672581196,
0.10856691002845764,
-0.07207928597927094],
[-0.05961983650922775,
0.16516649723052979,
0.10900723189115524,
0.22902508080005646,
0.1622152030467987,
0.2404518872499466,
0.14683611690998077,
0.05489376187324524,
0.1488243192434311,
-0.12612617015838623],
[-0.09663838893175125,
0.17179463803768158,
0.15472890436649323,
0.31104403734207153,
0.2644428312778473,
0.32147520780563354,
0.21147173643112183,
0.08512993156909943,
0.20477068424224854,
-0.20753706991672516],
[-0.1476699709892273,
0.16157129406929016,
0.2033478319644928,
0.4017667770385742,
0.3551658093929291,
0.40257886052131653,
0.2912360429763794,
0.13407102227210999,
0.2749353349208832,
-0.2794758379459381],
[-0.20422206819057465,
0.16578540205955505,
0.25083914399147034,
0.46711477637290955,
0.43075132369995117,
0.47971758246421814,
0.3764747381210327,
0.2016037255525589,
0.36287787556648254,
-0.33587899804115295],
[-0.26813051104545593,
0.1662045270204544,
0.28966930508613586,
0.5192147493362427,
0.46377167105674744,
0.5471265316009521,
0.45559510588645935,
0.3056398928165436,
0.44026315212249756,
-0.37334367632865906],
[-0.3387893736362457,
0.16086536645889282,
0.3295435309410095,
0.5598940849304199,
0.4894082844257355,
0.6116238236427307,
0.5161793231964111,
0.4252799153327942,
0.5001445412635803,
-0.37833917140960693],
[-0.38768890500068665,
0.1690857857465744,
0.351093590259552,
0.5892294645309448,
0.49744564294815063,
0.6454594731330872,
0.5497063398361206,
0.5187475085258484,
0.5375459790229797,
-0.3840266764163971]],
[[-0.007385450880974531,
0.025550365447998047,
0.011329921893775463,
0.03704448416829109,
0.039401981979608536,
0.027618983760476112,
0.017678581178188324,
0.025421805679798126,
0.01569029688835144,
-0.02289711870253086],
[-0.016434617340564728,
0.04202163591980934,
0.03387784957885742,
0.09133366495370865,
0.09891992807388306,
0.05646282434463501,
0.04472542181611061,
0.0426873117685318,
0.04018104821443558,
-0.05347241833806038],
[-0.02927660197019577,
0.058802951127290726,
0.07028387486934662,
0.16001532971858978,
0.18033723533153534,
0.10849107801914215,
0.09034308046102524,
0.057852793484926224,
0.0845632553100586,
-0.11631344258785248],
[-0.058248504996299744,
0.0690665915608406,
0.11151129007339478,
0.2467028647661209,
0.2727585434913635,
0.17967411875724792,
0.15757513046264648,
0.08372584730386734,
0.1453004628419876,
-0.1782950609922409],
[-0.10167629271745682,
0.08879223465919495,
0.14719465374946594,
0.31988584995269775,
0.3662298917770386,
0.2775735855102539,
0.23900474607944489,
0.12316052615642548,
0.22674421966075897,
-0.22890068590641022],
[-0.1624550074338913,
0.09984000772237778,
0.18285717070102692,
0.3849408030509949,
0.42069122195243835,
0.37895315885543823,
0.3279584050178528,
0.19788342714309692,
0.3057538866996765,
-0.26531723141670227],
[-0.23845139145851135,
0.10098498314619064,
0.2175826132297516,
0.4453183114528656,
0.4605388939380646,
0.47577205300331116,
0.40886789560317993,
0.3037096858024597,
0.3751377761363983,
-0.27267539501190186],
[-0.30577918887138367,
0.11194641143083572,
0.23216237127780914,
0.4968763589859009,
0.474517285823822,
0.530599057674408,
0.4657686650753021,
0.4047018885612488,
0.42059555649757385,
-0.27866390347480774],
[-0.3670070171356201,
0.12681278586387634,
0.24325476586818695,
0.5356999039649963,
0.48277804255485535,
0.5690165162086487,
0.507861852645874,
0.4898432195186615,
0.45843878388404846,
-0.2800208032131195],
[-0.4167822003364563,
0.14109325408935547,
0.250868558883667,
0.5658891201019287,
0.4877112805843353,
0.5969432592391968,
0.5366802215576172,
0.5557106733322144,
0.4856134355068207,
-0.2771853506565094]],
[[-0.007385450880974531,
0.025550365447998047,
0.011329921893775463,
0.03704448416829109,
0.039401981979608536,
0.027618983760476112,
0.017678581178188324,
0.025421805679798126,
0.01569029688835144,
-0.02289711870253086],
[-0.016434617340564728,
0.04202163591980934,
0.03387784957885742,
0.09133366495370865,
0.09891992807388306,
0.05646282434463501,
0.04472542181611061,
0.0426873117685318,
0.04018104821443558,
-0.05347241833806038],
[-0.02927660197019577,
0.058802951127290726,
0.07028387486934662,
0.16001532971858978,
0.18033723533153534,
0.10849107801914215,
0.09034308046102524,
0.057852793484926224,
0.0845632553100586,
-0.11631344258785248],
[-0.058248504996299744,
0.0690665915608406,
0.11151129007339478,
0.2467028647661209,
0.2727585434913635,
0.17967411875724792,
0.15757513046264648,
0.08372584730386734,
0.1453004628419876,
-0.1782950609922409],
[-0.10211649537086487,
0.07676362991333008,
0.15105488896369934,
0.34038570523262024,
0.3536534905433655,
0.26201990246772766,
0.23862825334072113,
0.12332990765571594,
0.21722912788391113,
-0.23994718492031097],
[-0.1574753373861313,
0.0961245521903038,
0.18575678765773773,
0.412643164396286,
0.42478057742118835,
0.35893985629081726,
0.32350555062294006,
0.17743146419525146,
0.3039356470108032,
-0.2907536029815674],
[-0.22027722001075745,
0.11041035503149033,
0.2142597734928131,
0.4794926047325134,
0.45950832962989807,
0.46049273014068604,
0.40720677375793457,
0.25972336530685425,
0.3760809004306793,
-0.3337879776954651],
[-0.2838820815086365,
0.12443418055772781,
0.24316547811031342,
0.528627872467041,
0.48025697469711304,
0.5452204942703247,
0.47800135612487793,
0.35232090950012207,
0.43860000371932983,
-0.3651861548423767],
[-0.34057822823524475,
0.13886594772338867,
0.27253884077072144,
0.5640837550163269,
0.4918398857116699,
0.607871949672699,
0.5323774814605713,
0.4408590495586395,
0.48914486169815063,
-0.3854181170463562],
[-0.38752374053001404,
0.15441176295280457,
0.3025898039340973,
0.5882860422134399,
0.49890464544296265,
0.6517040729522705,
0.5718855857849121,
0.5152390003204346,
0.5280877351760864,
-0.3967254161834717]],
[[0.001626665354706347,
0.037479467689991,
0.019430629909038544,
0.03404674306511879,
0.012370062991976738,
0.03443584218621254,
0.023976050317287445,
0.009455711580812931,
0.026551280170679092,
-0.015678660944104195],
[-0.014866768382489681,
0.07031739503145218,
0.037209831178188324,
0.08887828141450882,
0.06364879757165909,
0.08181287348270416,
0.05092545226216316,
0.03311295807361603,
0.051180221140384674,
-0.04780920222401619],
[-0.03127816319465637,
0.09107029438018799,
0.07287603616714478,
0.15671797096729279,
0.1430317461490631,
0.14184020459651947,
0.09511943906545639,
0.050354160368442535,
0.09149425476789474,
-0.11039866507053375],
[-0.059480901807546616,
0.11047469079494476,
0.10961417853832245,
0.22083410620689392,
0.2510247528553009,
0.23019792139530182,
0.15908633172512054,
0.07584737241268158,
0.15513871610164642,
-0.16032931208610535],
[-0.10637495666742325,
0.1215302050113678,
0.14279113709926605,
0.28813230991363525,
0.3402049243450165,
0.3377043902873993,
0.2429751753807068,
0.12786360085010529,
0.22257700562477112,
-0.2063126415014267],
[-0.16847631335258484,
0.12283267825841904,
0.18033909797668457,
0.34652021527290344,
0.40150222182273865,
0.43177521228790283,
0.33234745264053345,
0.2077411562204361,
0.29760900139808655,
-0.23532544076442719],
[-0.23343591392040253,
0.13489453494548798,
0.21264009177684784,
0.39830711483955383,
0.4422001242637634,
0.4987843334674835,
0.41326168179512024,
0.2947700619697571,
0.3689412474632263,
-0.26142585277557373],
[-0.3018285632133484,
0.1307247281074524,
0.2514607012271881,
0.4561883807182312,
0.4691040813922882,
0.5695794224739075,
0.48165783286094666,
0.3989560306072235,
0.43338218331336975,
-0.2604142129421234],
[-0.3521101176738739,
0.141281396150589,
0.271355003118515,
0.505827009677887,
0.48101359605789185,
0.6069961786270142,
0.5242564678192139,
0.48353955149650574,
0.476498544216156,
-0.26560845971107483],
[-0.39889639616012573,
0.14655019342899323,
0.2772831618785858,
0.5432493686676025,
0.5000923871994019,
0.6239126920700073,
0.5520086288452148,
0.5288835763931274,
0.5202540159225464,
-0.25520721077919006]],
[[-0.007385450880974531,
0.025550365447998047,
0.011329921893775463,
0.03704448416829109,
0.039401981979608536,
0.027618983760476112,
0.017678581178188324,
0.025421805679798126,
0.01569029688835144,
-0.02289711870253086],
[-0.016434617340564728,
0.04202163591980934,
0.03387784957885742,
0.09133366495370865,
0.09891992807388306,
0.05646282434463501,
0.04472542181611061,
0.0426873117685318,
0.04018104821443558,
-0.05347241833806038],
[-0.02927660197019577,
0.058802951127290726,
0.07028387486934662,
0.16001532971858978,
0.18033723533153534,
0.10849107801914215,
0.09034308046102524,
0.057852793484926224,
0.0845632553100586,
-0.11631344258785248],
[-0.057247862219810486,
0.08008474111557007,
0.10546114295721054,
0.2235705405473709,
0.2839057445526123,
0.19722367823123932,
0.15593916177749634,
0.0836041271686554,
0.15329524874687195,
-0.1659594476222992],
[-0.10634910315275192,
0.09204169362783432,
0.14448414742946625,
0.2828052043914795,
0.3588876724243164,
0.30093318223953247,
0.23928526043891907,
0.14087560772895813,
0.22557410597801208,
-0.20268292725086212],
[-0.16713985800743103,
0.10960502922534943,
0.17833401262760162,
0.3384683132171631,
0.415674090385437,
0.39195284247398376,
0.3292734920978546,
0.2125479131937027,
0.299818217754364,
-0.23643545806407928],
[-0.23756016790866852,
0.10867643356323242,
0.2169070541858673,
0.4016650319099426,
0.45342424511909485,
0.4848024845123291,
0.4118563234806061,
0.31293436884880066,
0.3700428307056427,
-0.24362049996852875],
[-0.3006177842617035,
0.11638326942920685,
0.2345806509256363,
0.4584605097770691,
0.4678575098514557,
0.536447286605835,
0.468919575214386,
0.4064263701438904,
0.4150925278663635,
-0.25325480103492737],
[-0.3589994013309479,
0.12805280089378357,
0.2480306476354599,
0.5039001703262329,
0.47728273272514343,
0.5728580355644226,
0.5102958679199219,
0.48629629611968994,
0.4536206126213074,
-0.2593274414539337],
[-0.40764355659484863,
0.13997741043567657,
0.2571384608745575,
0.54050213098526,
0.4833797812461853,
0.5996587872505188,
0.5384678840637207,
0.5493736267089844,
0.4821021854877472,
-0.2614176869392395]]]
In [52]:
import dynamics_model_class as dm
In [53]:
dmodel = dm.DynamicsModel(model_id="test_model", load_checkpoint=False)
Loading RNN dynamics model...
Directory path for tensorboard summaries: ../tensorboard_logs/test_model/
Checkpoint directory path: ../checkpoints/test_model/
Model loaded.
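At each timestep the network emits one probability per exercise, but only the exercise actually chosen has a binary target. The loss should therefore count only that one entry. Below is a minimal pure-Python sketch of a masked binary cross-entropy. It assumes `output_mask` is one-hot over exercises at each timestep. It illustrates the idea and is not the loss implemented inside `dynamics_model_class`, whose source is not shown here. The function name is hypothetical.

```python
import math

def masked_binary_crossentropy(probs, mask, targets, eps=1e-7):
    """Mean binary cross-entropy over the entries selected by `mask`.

    probs, mask, targets: nested lists shaped (n_students, n_timesteps, n_exercises).
    mask[s][t][e] is assumed to be 1 for the exercise chosen at timestep t, else 0;
    targets[s][t][e] is 1 if that exercise was solved, else 0.
    """
    total, count = 0.0, 0.0
    for p_s, m_s, y_s in zip(probs, mask, targets):
        for p_t, m_t, y_t in zip(p_s, m_s, y_s):
            for p, m, y in zip(p_t, m_t, y_t):
                if m == 0:
                    continue  # exercises that were not chosen carry no target
                p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
                total += -m * (y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
                count += m
    return total / count if count else 0.0

# One student, two timesteps, three exercises: exercise 0 solved, then exercise 1 failed.
probs   = [[[0.9, 0.2, 0.5], [0.1, 0.7, 0.4]]]
mask    = [[[1, 0, 0], [0, 1, 0]]]
targets = [[[1, 0, 0], [0, 0, 0]]]
print(masked_binary_crossentropy(probs, mask, targets))  # about 0.655
```

Unmasked entries, such as 0.2 and 0.5 at the first timestep, do not affect the loss. Predictions for exercises the student did not attempt therefore receive no direct training signal from that timestep.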
In [56]:
dmodel.train(train_data)
Training Step: 1019 | total loss: 0.02487 | time: 0.567s
| Adam | epoch: 068 | loss: 0.02487 -- iter: 896/900
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-56-56596559dd51> in <module>()
----> 1 dmodel.train(train_data)
/Users/lisa1010/dev/smart-tutor/code/dynamics_model_class.pyc in train(self, train_data, load_checkpoint)
86 date_time_string = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
87 run_id = "{}".format(date_time_string)
---> 88 self.model.fit([input_data, output_mask], output_data, n_epoch=64, validation_set=0.1)
89
90
KeyboardInterrupt:
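The progress counters in the interrupted log are consistent with one another. Assume tflearn's default batch size of 64, which is the regression layer's default; the batch size actually configured in `dynamics_model_class` is not visible here. Under that assumption:

- 900 training sequences give ceil(900 / 64) = 15 batches per epoch.
- `iter: 896/900` is the 14th batch of epoch 68.
- (68 - 1) * 15 + 14 = 1019, which matches the reported training step.

This small helper, with a hypothetical name, reproduces that arithmetic:

```python
import math

def global_step(epoch, iter_in_epoch, n_train, batch_size=64):
    """Global training step for a 1-based epoch and the sample counter shown in the log.

    batch_size=64 is an assumption (tflearn's regression-layer default).
    """
    steps_per_epoch = int(math.ceil(n_train / float(batch_size)))      # 900 -> 15 batches
    batch_in_epoch = int(math.ceil(iter_in_epoch / float(batch_size)))  # 896 -> 14th batch
    return (epoch - 1) * steps_per_epoch + batch_in_epoch

print(global_step(epoch=68, iter_in_epoch=896, n_train=900))  # 1019
```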
In [57]:
preds = dmodel.predict(input_data_[:10,:10, :])
preds  # display the predictions
Out[57]:
[[[0.49599671363830566,
0.505977988243103,
0.49741360545158386,
0.5000333189964294,
0.498515248298645,
0.5003623366355896,
0.5020152926445007,
0.4968193471431732,
0.5021724104881287,
0.49825751781463623],
[0.48964592814445496,
0.5083518624305725,
0.49532902240753174,
0.5052512884140015,
0.49550819396972656,
0.5047191381454468,
0.49943220615386963,
0.49179941415786743,
0.5016341805458069,
0.4988301992416382],
[0.48247480392456055,
0.5100704431533813,
0.4926689863204956,
0.5104095339775085,
0.4944813549518585,
0.5111148357391357,
0.49826157093048096,
0.4853592813014984,
0.4960009455680847,
0.49970728158950806],
[0.4703198969364166,
0.5135535001754761,
0.4915870130062103,
0.5127080678939819,
0.49413204193115234,
0.5221576690673828,
0.502672016620636,
0.47870734333992004,
0.4900960922241211,
0.5019947290420532],
[0.4537630081176758,
0.520048201084137,
0.49032503366470337,
0.5139161944389343,
0.4934532642364502,
0.5316247344017029,
0.5167728662490845,
0.4716108441352844,
0.4958451986312866,
0.4966897964477539],
[0.43617871403694153,
0.5289715528488159,
0.48939430713653564,
0.5126253962516785,
0.4925546646118164,
0.5344418883323669,
0.5401607155799866,
0.4698314368724823,
0.5158414244651794,
0.4840362071990967],
[0.42516717314720154,
0.5388312339782715,
0.48898425698280334,
0.508516788482666,
0.49308663606643677,
0.5259639620780945,
0.5653541088104248,
0.4776161313056946,
0.5385245084762573,
0.46874111890792847],
[0.4183195233345032,
0.5485021471977234,
0.4903010129928589,
0.5027713179588318,
0.48928093910217285,
0.5091527700424194,
0.5853952765464783,
0.4925719201564789,
0.557790994644165,
0.456940621137619],
[0.41411229968070984,
0.559482753276825,
0.49123120307922363,
0.497923344373703,
0.4830032289028168,
0.48612239956855774,
0.6010857820510864,
0.5239127278327942,
0.5803111791610718,
0.44925323128700256],
[0.4249594509601593,
0.569526195526123,
0.4904220998287201,
0.4949359595775604,
0.46843722462654114,
0.45716696977615356,
0.6109771728515625,
0.5648123025894165,
0.5905900597572327,
0.4446971118450165]],
[[0.4935073256492615,
0.5066367387771606,
0.5006406307220459,
0.4931355118751526,
0.5011196732521057,
0.5027900338172913,
0.5059390664100647,
0.5037103891372681,
0.5045592188835144,
0.49498361349105835],
[0.4888042211532593,
0.5124666690826416,
0.49886104464530945,
0.48768845200538635,
0.502082884311676,
0.5015733242034912,
0.5139681100845337,
0.5072578191757202,
0.5110902190208435,
0.4893607795238495],
[0.4837706387042999,
0.5132308006286621,
0.4966180622577667,
0.4896356761455536,
0.5010759830474854,
0.5032262802124023,
0.5173490047454834,
0.5093849897384644,
0.5153816342353821,
0.48739975690841675],
[0.47697913646698,
0.5159948468208313,
0.4962058365345001,
0.5016492605209351,
0.5036230087280273,
0.5028076767921448,
0.5183938145637512,
0.5096410512924194,
0.5118613839149475,
0.48518186807632446],
[0.47057032585144043,
0.5163217782974243,
0.49360454082489014,
0.5125990509986877,
0.5019816160202026,
0.5014887452125549,
0.5191658139228821,
0.5111626386642456,
0.5062094330787659,
0.483379989862442],
[0.4688839614391327,
0.5220568776130676,
0.49056243896484375,
0.5243667960166931,
0.49032846093177795,
0.4946741461753845,
0.5149816870689392,
0.5242074728012085,
0.5099191665649414,
0.47194865345954895],
[0.47019073367118835,
0.5310572981834412,
0.4862220287322998,
0.5377153158187866,
0.47338351607322693,
0.4833642244338989,
0.4996504783630371,
0.5456037521362305,
0.5182302594184875,
0.456155389547348],
[0.469103068113327,
0.5361096262931824,
0.47985541820526123,
0.5500144362449646,
0.4661874771118164,
0.475056529045105,
0.4756297767162323,
0.5621740221977234,
0.5216403603553772,
0.4482935965061188],
[0.4695298969745636,
0.5345526933670044,
0.4739120900630951,
0.5572773814201355,
0.4666861593723297,
0.4686569571495056,
0.4480878710746765,
0.5711529850959778,
0.522301435470581,
0.44190213084220886],
[0.468523770570755,
0.525304913520813,
0.4741703271865845,
0.5627244710922241,
0.4680138826370239,
0.4681204557418823,
0.42252233624458313,
0.5644826889038086,
0.5175504088401794,
0.4397382140159607]],
[[0.4935073256492615,
0.5066367387771606,
0.5006406307220459,
0.4931355118751526,
0.5011196732521057,
0.5027900338172913,
0.5059390664100647,
0.5037103891372681,
0.5045592188835144,
0.49498361349105835],
[0.4864151179790497,
0.5134059190750122,
0.5013522505760193,
0.4798593819141388,
0.5054095983505249,
0.503986120223999,
0.5183268785476685,
0.5136097073554993,
0.5136396884918213,
0.4860149323940277],
[0.48251864314079285,
0.5177899599075317,
0.49937862157821655,
0.4698704779148102,
0.5096774697303772,
0.4998442530632019,
0.5328044295310974,
0.5246942043304443,
0.5241942405700684,
0.4782741963863373],
[0.48027995228767395,
0.5149013996124268,
0.49656760692596436,
0.46929213404655457,
0.511268675327301,
0.4979315996170044,
0.5424811840057373,
0.5347463488578796,
0.5325235724449158,
0.4752316176891327],
[0.48054632544517517,
0.5101013779640198,
0.49272894859313965,
0.4788307845592499,
0.5128154158592224,
0.495523601770401,
0.5500748753547668,
0.5437270998954773,
0.5377021431922913,
0.473631352186203],
[0.47937169671058655,
0.5055421590805054,
0.48871538043022156,
0.49055591225624084,
0.5119155049324036,
0.497022807598114,
0.5583524703979492,
0.5503260493278503,
0.5420855283737183,
0.47418126463890076],
[0.47849005460739136,
0.5042028427124023,
0.48473697900772095,
0.5034549832344055,
0.5114812254905701,
0.4919143319129944,
0.5679455995559692,
0.5568810105323792,
0.552671492099762,
0.4691372215747833],
[0.47729960083961487,
0.5018625855445862,
0.4813368618488312,
0.5148546695709229,
0.5120936632156372,
0.482450932264328,
0.5753395557403564,
0.5665517449378967,
0.5611532926559448,
0.4625680148601532],
[0.4716723561286926,
0.4921557605266571,
0.4792180359363556,
0.5260759592056274,
0.5021523833274841,
0.4689690172672272,
0.5777547359466553,
0.5858073830604553,
0.5734472870826721,
0.4575400948524475],
[0.46615177392959595,
0.47498705983161926,
0.47839879989624023,
0.5374900102615356,
0.48372694849967957,
0.4545191824436188,
0.5819700360298157,
0.609631359577179,
0.585426926612854,
0.4527510702610016]],
[[0.4935073256492615,
0.5066367387771606,
0.5006406307220459,
0.4931355118751526,
0.5011196732521057,
0.5027900338172913,
0.5059390664100647,
0.5037103891372681,
0.5045592188835144,
0.49498361349105835],
[0.4888042211532593,
0.5124666690826416,
0.49886104464530945,
0.48768845200538635,
0.502082884311676,
0.5015733242034912,
0.5139681100845337,
0.5072578191757202,
0.5110902190208435,
0.4893607795238495],
[0.4929240345954895,
0.516448974609375,
0.4938470125198364,
0.49369972944259644,
0.5044259428977966,
0.4983064830303192,
0.5112334489822388,
0.5068015456199646,
0.5119608640670776,
0.4841019809246063],
[0.5007560849189758,
0.5182065367698669,
0.4872228801250458,
0.5079019665718079,
0.5078688263893127,
0.49604564905166626,
0.4969274699687958,
0.5014709830284119,
0.5072579383850098,
0.48117199540138245],
[0.5020344257354736,
0.5155802965164185,
0.48553648591041565,
0.521838903427124,
0.5078266859054565,
0.5006723403930664,
0.4766537845134735,
0.49318230152130127,
0.4999012053012848,
0.4852582514286041],
[0.4984493851661682,
0.5138178467750549,
0.4931369125843048,
0.5304259061813354,
0.5104289650917053,
0.5083906054496765,
0.4594757854938507,
0.48229333758354187,
0.483430951833725,
0.49318283796310425],
[0.48839783668518066,
0.5124683380126953,
0.5079970359802246,
0.5351042151451111,
0.5117958188056946,
0.5207926630973816,
0.4490090012550354,
0.4710412621498108,
0.4594714045524597,
0.5025239586830139],
[0.4691604673862457,
0.5104724168777466,
0.5210692286491394,
0.5382067561149597,
0.5093942284584045,
0.5367593765258789,
0.44400638341903687,
0.4590891897678375,
0.4302017092704773,
0.513044536113739],
[0.43615126609802246,
0.5113045573234558,
0.5329358577728271,
0.5375150442123413,
0.5057876706123352,
0.558469295501709,
0.44797810912132263,
0.44826582074165344,
0.4029180407524109,
0.5220268368721008],
[0.39566367864608765,
0.5144246220588684,
0.5379319787025452,
0.5362714529037476,
0.5033884644508362,
0.5835130214691162,
0.4657069146633148,
0.4361811578273773,
0.38456860184669495,
0.5213557481765747]],
[[0.49599671363830566,
0.505977988243103,
0.49741360545158386,
0.5000333189964294,
0.498515248298645,
0.5003623366355896,
0.5020152926445007,
0.4968193471431732,
0.5021724104881287,
0.49825751781463623],
[0.4982739984989166,
0.5114309191703796,
0.4927102029323578,
0.5087923407554626,
0.4982874393463135,
0.5002113580703735,
0.4950408637523651,
0.49080711603164673,
0.4985194206237793,
0.495974600315094],
[0.495180606842041,
0.5128109455108643,
0.49134019017219543,
0.5186378359794617,
0.4964984357357025,
0.505379319190979,
0.4830396771430969,
0.4827333390712738,
0.4921058714389801,
0.49838364124298096],
[0.488839328289032,
0.5130492448806763,
0.4918130338191986,
0.5242824554443359,
0.4961698055267334,
0.5140914916992188,
0.47307127714157104,
0.47317343950271606,
0.4796505272388458,
0.5027008652687073],
[0.48134538531303406,
0.5185043215751648,
0.4974914789199829,
0.5281264185905457,
0.49063757061958313,
0.5196006894111633,
0.46588170528411865,
0.4708947241306305,
0.4721260964870453,
0.49853694438934326],
[0.46692249178886414,
0.5220759510993958,
0.5048929452896118,
0.5288379788398743,
0.4898432195186615,
0.5277711153030396,
0.4598439633846283,
0.46979770064353943,
0.4605463743209839,
0.4996195435523987],
[0.44905009865760803,
0.52361661195755,
0.5140159130096436,
0.5291934609413147,
0.4894247353076935,
0.5352762937545776,
0.45834168791770935,
0.4670118987560272,
0.44874122738838196,
0.496447890996933],
[0.42665520310401917,
0.5251380205154419,
0.5297182202339172,
0.5279824733734131,
0.48798325657844543,
0.5414242148399353,
0.4594918489456177,
0.46260207891464233,
0.43408021330833435,
0.49125513434410095],
[0.4031164348125458,
0.5257107615470886,
0.5459632873535156,
0.5304495096206665,
0.4891797602176666,
0.5490188598632812,
0.4588281512260437,
0.44747695326805115,
0.4023078680038452,
0.4877915680408478],
[0.3729502260684967,
0.5255358219146729,
0.5609372854232788,
0.5321497917175293,
0.48992788791656494,
0.5642527341842651,
0.46293866634368896,
0.4378302991390228,
0.3744458854198456,
0.4855373203754425]],
[[0.4935073256492615,
0.5066367387771606,
0.5006406307220459,
0.4931355118751526,
0.5011196732521057,
0.5027900338172913,
0.5059390664100647,
0.5037103891372681,
0.5045592188835144,
0.49498361349105835],
[0.4864151179790497,
0.5134059190750122,
0.5013522505760193,
0.4798593819141388,
0.5054095983505249,
0.503986120223999,
0.5183268785476685,
0.5136097073554993,
0.5136396884918213,
0.4860149323940277],
[0.4804767370223999,
0.5187111496925354,
0.5010325312614441,
0.4613550305366516,
0.5138576626777649,
0.5021450519561768,
0.5373328924179077,
0.5301012396812439,
0.5266373753547668,
0.4751938283443451],
[0.4786783754825592,
0.5192803740501404,
0.49868643283843994,
0.4485275149345398,
0.5217024087905884,
0.4942969083786011,
0.5571663975715637,
0.548382043838501,
0.5394854545593262,
0.4674091041088104],
[0.4800170063972473,
0.5085130929946899,
0.4956192672252655,
0.44785237312316895,
0.5255739688873291,
0.4887203872203827,
0.5711811184883118,
0.5651856660842896,
0.5490503311157227,
0.46478331089019775],
[0.48153260350227356,
0.49942296743392944,
0.4930483400821686,
0.4832015037536621,
0.5232486724853516,
0.47565048933029175,
0.5736194252967834,
0.5832151174545288,
0.5526918172836304,
0.46285244822502136],
[0.4864833652973175,
0.48481810092926025,
0.48882997035980225,
0.5194024443626404,
0.513222873210907,
0.46280092000961304,
0.5727747678756714,
0.5997238159179688,
0.5543747544288635,
0.45990189909935],
[0.4895656704902649,
0.4684159457683563,
0.4830937385559082,
0.5518429279327393,
0.4951283931732178,
0.4554482400417328,
0.5680369734764099,
0.6143181324005127,
0.5561201572418213,
0.4577709138393402],
[0.4911474585533142,
0.4511832296848297,
0.4759911000728607,
0.5732237100601196,
0.48121845722198486,
0.4484676420688629,
0.5565111637115479,
0.6265982389450073,
0.5558149814605713,
0.4528338611125946],
[0.4923345148563385,
0.4308885931968689,
0.4694708287715912,
0.5909002423286438,
0.46782487630844116,
0.4447915256023407,
0.5333371758460999,
0.6277155876159668,
0.553862452507019,
0.4495563507080078]],
[[0.49599671363830566,
0.505977988243103,
0.49741360545158386,
0.5000333189964294,
0.498515248298645,
0.5003623366355896,
0.5020152926445007,
0.4968193471431732,
0.5021724104881287,
0.49825751781463623],
[0.4982739984989166,
0.5114309191703796,
0.4927102029323578,
0.5087923407554626,
0.4982874393463135,
0.5002113580703735,
0.4950408637523651,
0.49080711603164673,
0.4985194206237793,
0.495974600315094],
[0.495180606842041,
0.5128109455108643,
0.49134019017219543,
0.5186378359794617,
0.4964984357357025,
0.505379319190979,
0.4830396771430969,
0.4827333390712738,
0.4921058714389801,
0.49838364124298096],
[0.48699524998664856,
0.514747679233551,
0.4969709813594818,
0.5244477987289429,
0.5002292394638062,
0.5135923027992249,
0.47346407175064087,
0.4732531011104584,
0.47692635655403137,
0.5020098090171814],
[0.4732399582862854,
0.5144741535186768,
0.501699686050415,
0.5289419889450073,
0.5003151893615723,
0.5248336791992188,
0.46726882457733154,
0.46334242820739746,
0.4561832547187805,
0.5071861743927002],
[0.44964221119880676,
0.5158843994140625,
0.5079138875007629,
0.5293943881988525,
0.4989471137523651,
0.5415902733802795,
0.4693746268749237,
0.45465245842933655,
0.43583324551582336,
0.5127857327461243],
[0.4221138060092926,
0.5176545977592468,
0.5147653222084045,
0.5295987129211426,
0.49693167209625244,
0.5577367544174194,
0.47876542806625366,
0.4452420175075531,
0.41869810223579407,
0.5082810521125793],
[0.38811028003692627,
0.5214492082595825,
0.5184506177902222,
0.5253399014472961,
0.49174442887306213,
0.5693498253822327,
0.5011019706726074,
0.45353057980537415,
0.4164471924304962,
0.4968683123588562],
[0.36234354972839355,
0.5278410911560059,
0.5153379440307617,
0.5202239751815796,
0.48276323080062866,
0.572359561920166,
0.5301076769828796,
0.4818229079246521,
0.4330885708332062,
0.4801550805568695],
[0.35057127475738525,
0.5377871990203857,
0.5082552433013916,
0.5131154656410217,
0.4674236476421356,
0.5599708557128906,
0.5598285794258118,
0.533362865447998,
0.4811905324459076,
0.4621477723121643]],
[[0.49599671363830566,
0.505977988243103,
0.49741360545158386,
0.5000333189964294,
0.498515248298645,
0.5003623366355896,
0.5020152926445007,
0.4968193471431732,
0.5021724104881287,
0.49825751781463623],
[0.4982739984989166,
0.5114309191703796,
0.4927102029323578,
0.5087923407554626,
0.4982874393463135,
0.5002113580703735,
0.4950408637523651,
0.49080711603164673,
0.4985194206237793,
0.495974600315094],
[0.495180606842041,
0.5128109455108643,
0.49134019017219543,
0.5186378359794617,
0.4964984357357025,
0.505379319190979,
0.4830396771430969,
0.4827333390712738,
0.4921058714389801,
0.49838364124298096],
[0.48699524998664856,
0.514747679233551,
0.4969709813594818,
0.5244477987289429,
0.5002292394638062,
0.5135923027992249,
0.47346407175064087,
0.4732531011104584,
0.47692635655403137,
0.5020098090171814],
[0.47155246138572693,
0.5163626074790955,
0.5073354840278625,
0.5284625291824341,
0.5036620497703552,
0.5252302885055542,
0.4685995578765869,
0.4637703597545624,
0.4550075829029083,
0.505318284034729],
[0.44791990518569946,
0.5157428979873657,
0.5143195390701294,
0.5320842862129211,
0.5027382373809814,
0.5384891629219055,
0.46804317831993103,
0.4545551836490631,
0.4287174344062805,
0.509219229221344],
[0.4274927079677582,
0.5203395485877991,
0.5206096172332764,
0.5363749265670776,
0.4941402077674866,
0.5440685153007507,
0.47415691614151,
0.456539124250412,
0.4133474826812744,
0.49973049759864807],
[0.41368141770362854,
0.5285932421684265,
0.5237830281257629,
0.5407341122627258,
0.47680577635765076,
0.5372069478034973,
0.4789503216743469,
0.47505199909210205,
0.40986666083335876,
0.4800032675266266],
[0.40893372893333435,
0.540021538734436,
0.5230647921562195,
0.5466728806495667,
0.453422874212265,
0.5180574059486389,
0.4767017364501953,
0.5112937092781067,
0.4170326292514801,
0.4565376937389374],
[0.412627249956131,
0.5536556839942932,
0.5195404291152954,
0.5547642111778259,
0.4317971467971802,
0.49344614148139954,
0.4625207781791687,
0.5525162220001221,
0.4331713318824768,
0.43483832478523254]],
[[0.4935073256492615,
0.5066367387771606,
0.5006406307220459,
0.4931355118751526,
0.5011196732521057,
0.5027900338172913,
0.5059390664100647,
0.5037103891372681,
0.5045592188835144,
0.49498361349105835],
[0.4888042211532593,
0.5124666690826416,
0.49886104464530945,
0.48768845200538635,
0.502082884311676,
0.5015733242034912,
0.5139681100845337,
0.5072578191757202,
0.5110902190208435,
0.4893607795238495],
[0.4837706387042999,
0.5132308006286621,
0.4966180622577667,
0.4896356761455536,
0.5010759830474854,
0.5032262802124023,
0.5173490047454834,
0.5093849897384644,
0.5153816342353821,
0.48739975690841675],
[0.4796925485134125,
0.5129684805870056,
0.49323591589927673,
0.495869517326355,
0.501282811164856,
0.50560462474823,
0.5209391713142395,
0.5099912881851196,
0.5158118009567261,
0.4864618182182312],
[0.4783409535884857,
0.5181713700294495,
0.49011996388435364,
0.503265917301178,
0.49712297320365906,
0.5029231905937195,
0.5176870822906494,
0.5175051093101501,
0.5209943652153015,
0.4768555462360382],
[0.47424039244651794,
0.5219233632087708,
0.4868907630443573,
0.5113959312438965,
0.49551692605018616,
0.5024139285087585,
0.5127443671226501,
0.5239893794059753,
0.5231109261512756,
0.4731998145580292],
[0.4707843065261841,
0.5274010300636292,
0.4835318326950073,
0.5203237533569336,
0.4962611496448517,
0.4970822036266327,
0.5096275210380554,
0.527242124080658,
0.5329503417015076,
0.46647271513938904],
[0.4672967195510864,
0.5302364826202393,
0.4804839491844177,
0.525964617729187,
0.49858424067497253,
0.4899025857448578,
0.5045422911643982,
0.5310878157615662,
0.5402598977088928,
0.4598078727722168],
[0.46306389570236206,
0.5290700197219849,
0.4814821779727936,
0.529658317565918,
0.49663886427879333,
0.48454758524894714,
0.49519726634025574,
0.5276616215705872,
0.5427936315536499,
0.45509234070777893],
[0.4558480978012085,
0.53023362159729,
0.4821963608264923,
0.5352409482002258,
0.4964573383331299,
0.4807736575603485,
0.4828425943851471,
0.5261974334716797,
0.5431912541389465,
0.4511813819408417]],
[[0.49599671363830566,
0.505977988243103,
0.49741360545158386,
0.5000333189964294,
0.498515248298645,
0.5003623366355896,
0.5020152926445007,
0.4968193471431732,
0.5021724104881287,
0.49825751781463623],
[0.4982739984989166,
0.5114309191703796,
0.4927102029323578,
0.5087923407554626,
0.4982874393463135,
0.5002113580703735,
0.4950408637523651,
0.49080711603164673,
0.4985194206237793,
0.495974600315094],
[0.495180606842041,
0.5128109455108643,
0.49134019017219543,
0.5186378359794617,
0.4964984357357025,
0.505379319190979,
0.4830396771430969,
0.4827333390712738,
0.4921058714389801,
0.49838364124298096],
[0.488839328289032,
0.5130492448806763,
0.4918130338191986,
0.5242824554443359,
0.4961698055267334,
0.5140914916992188,
0.47307127714157104,
0.47317343950271606,
0.4796505272388458,
0.5027008652687073],
[0.47480571269989014,
0.5146147012710571,
0.49649596214294434,
0.5256413817405701,
0.4959128499031067,
0.5280333161354065,
0.4700377285480499,
0.46416711807250977,
0.4655117690563202,
0.5094239115715027],
[0.4534894526004791,
0.5187195539474487,
0.5011202096939087,
0.5259437561035156,
0.49559882283210754,
0.5434014797210693,
0.4788511395454407,
0.45399218797683716,
0.4594345688819885,
0.509539783000946],
[0.4300541579723358,
0.5228154063224792,
0.5059940814971924,
0.5240729451179504,
0.4956946074962616,
0.5538676977157593,
0.493369996547699,
0.4469212293624878,
0.45663511753082275,
0.4996611475944519],
[0.4003496468067169,
0.5280584096908569,
0.5086494088172913,
0.5175948143005371,
0.4922581613063812,
0.5592440366744995,
0.5186375975608826,
0.4586865305900574,
0.46907535195350647,
0.4862392842769623],
[0.3774715065956116,
0.535921037197113,
0.5061038136482239,
0.5092968940734863,
0.4850473701953888,
0.5549939870834351,
0.5482835173606873,
0.4893377721309662,
0.5029461979866028,
0.47050967812538147],
[0.3672855794429779,
0.5468208193778992,
0.501282811164856,
0.49827826023101807,
0.4723666310310364,
0.535592257976532,
0.5762588977813721,
0.5382131934165955,
0.5526970624923706,
0.45590513944625854]]]
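The output above has shape (10 students, 10 timesteps, 10 exercises). Entry `[s][t][e]` is the predicted probability that student `s` solves exercise `e` next, given the history up to timestep `t`. The stdlib helpers below read predictions stored as plain nested lists like the ones printed. The helper names are illustrative and not part of `dynamics_model_class`.

```python
def shape(nested):
    """Dimensions of a rectangular nested list, e.g. (n_students, n_timesteps, n_exercises)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)

def prob_correct(preds, student, timestep, exercise):
    """Predicted P(correct on `exercise` next | history up to `timestep`) for one student."""
    return preds[student][timestep][exercise]

def best_exercise(preds, student, timestep):
    """Index of the exercise with the highest predicted success probability."""
    row = preds[student][timestep]
    return max(range(len(row)), key=lambda e: row[e])

toy = [[[0.2, 0.7, 0.5], [0.6, 0.3, 0.4]]]
print(shape(toy))                  # (1, 2, 3)
print(best_exercise(toy, 0, 0))    # 1
print(prob_correct(toy, 0, 1, 2))  # 0.4
```

Applied to the predictions above, `best_exercise(preds, 0, 9)` would pick exercise 6, with a probability of about 0.611.

Several students also share an identical first row; students 0 and 4 are one example. This is expected if they begin with the same first observation, because each prediction depends only on the history so far.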
In [74]:
generator_model = dm.DynamicsModel(model_id="test_model", timesteps=1, load_checkpoint=False)
Loading RNN dynamics model...
Directory path for tensorboard summaries: ../tensorboard_logs/test_model/
Checkpoint directory path: ../checkpoints/test_model/
Model loaded.
In [70]:
preds = generator_model.predict(input_data_[:10,:1, :])
In [72]:
print preds[0]
[[0.5023449063301086, 0.4977347254753113, 0.4989054501056671, 0.4967387318611145, 0.5014218091964722, 0.5007814168930054, 0.49788951873779297, 0.5013884902000427, 0.5018420815467834, 0.49715477228164673]]
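A one-step dynamics model can act as a student simulator. At each step it samples an outcome from the predicted probability and appends the new observation to the history. The sketch below shows that loop with two hypothetical stand-ins:

- `predict_fn` plays the role of the model's predict call.
- `choose_fn` picks the next exercise.

The one-hot layout in `encode_observation` is an assumption about how `dataset_utils` encodes observations, not something confirmed by this notebook. It uses index = exercise + 10 * correct, which matches a 20-dimensional input for 10 exercises.

```python
import random

N_EXERCISES = 10  # one output per exercise; the input has 2 * N_EXERCISES = 20 dims

def encode_observation(exercise, correct, n_exercises=N_EXERCISES):
    """One-hot observation vector. ASSUMED layout: index = exercise + n_exercises * correct."""
    vec = [0.0] * (2 * n_exercises)
    vec[exercise + n_exercises * int(correct)] = 1.0
    return vec

def rollout(predict_fn, choose_fn, first_obs, n_steps, seed=0):
    """Simulate a student by feeding the growing history back into a predictor.

    predict_fn(history) -> list of per-exercise success probabilities (hypothetical stand-in).
    choose_fn(history)  -> index of the next exercise, i.e. the action (hypothetical).
    """
    rng = random.Random(seed)
    history = [first_obs]
    trajectory = []
    for _ in range(n_steps):
        probs = predict_fn(history)
        action = choose_fn(history)
        correct = rng.random() < probs[action]  # sample the next observation
        trajectory.append((action, correct))
        history.append(encode_observation(action, correct))
    return trajectory, history

# Toy run: a constant predictor (every exercise solved with probability 0.5), cycling exercises.
traj, hist = rollout(lambda h: [0.5] * N_EXERCISES, lambda h: len(h) % N_EXERCISES,
                     encode_observation(0, True), n_steps=5)
print(traj)
```

A model built with `timesteps=1` sees a single observation per predict call. Feeding back a longer history, as this loop does, would therefore need either the full-length model or explicit carry-over of the RNN state between calls.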
Content source: lisa-1010/smart-tutor