In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import tensorflow as tf
import tflearn
import numpy as np
from sklearn.model_selection import train_test_split

import drqn
import student as st

import data_generator as dg
import concept_dependency_graph as cdg
from experience_buffer import ExperienceBuffer
import dataset_utils as d_utils
import utils
import models_dict_utils
from drqn_tests import *

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

General Workflow

1. Create Data Set


In [2]:
n_concepts = 4
use_student2 = True
student2_str = '2' if use_student2 else ''
learn_prob = 0.15
lp_str = '-lp{}'.format(int(learn_prob*100)) if not use_student2 else ''
n_students = 100000
seqlen = 7
filter_mastery = False
filter_str = '' if not filter_mastery else '-filtered'
policy = 'expert'
filename = 'test{}-n{}-l{}{}-{}{}.pickle'.format(student2_str, n_students, seqlen,
                                                    lp_str, policy, filter_str)

only run the next two cells if dataset hasn't been created yet


In [3]:
#concept_tree = sm.create_custom_dependency()
concept_tree = cdg.ConceptDependencyGraph()
concept_tree.init_default_tree(n_concepts)
if not use_student2:
    test_student = st.Student(n=n_concepts,p_trans_satisfied=learn_prob, p_trans_not_satisfied=0.0, p_get_ex_correct_if_concepts_learned=1.0)
else:
    test_student = st.Student2(n_concepts)
print(filename)


test2-n100000-l7-expert.pickle

In [4]:
print ("Initializing synthetic data sets...")
dg.generate_data(concept_tree, student=test_student, n_students=n_students, filter_mastery=filter_mastery, seqlen=seqlen, policy=policy, filename="{}{}".format(dg.SYN_DATA_DIR, filename))
print ("Data generation completed. ")


Initializing synthetic data sets...
Generating data for 100000 students with behavior policy expert and sequence length 7.
Data generation completed. 

In [5]:
data = d_utils.load_data(filename="../synthetic_data/{}".format(filename))
dqn_data = d_utils.preprocess_data_for_dqn(data, reward_model="semisparse")
dqn_data_train, dqn_data_test = train_test_split(dqn_data, test_size=0.2)

In [6]:
# Creating training and validation data
train_buffer = ExperienceBuffer()
train_buffer.buffer = dqn_data_train
train_buffer.buffer_sz = len(train_buffer.buffer)

val_buffer = ExperienceBuffer()
val_buffer.buffer = dqn_data_test
val_buffer.buffer_sz = len(val_buffer.buffer)

In [7]:
print (train_buffer.sample(1))


[[[array([ 0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.])
   array([ 0.,  1.,  0.,  0.]) 0.0
   array([ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.])]
  [array([ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.])
   array([ 0.,  0.,  1.,  0.]) 0.0
   array([ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.])]
  [array([ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.])
   array([ 0.,  0.,  1.,  0.]) 0.0
   array([ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.])]
  [array([ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.])
   array([ 0.,  0.,  0.,  1.]) 0.0
   array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.])]
  [array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.])
   array([ 0.,  0.,  0.,  1.]) 0.0
   array([ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.])]
  [array([ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.])
   array([ 1.,  0.,  0.,  0.]) 4.0
   array([ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])]]]

2. Create Model and Train


In [9]:
model_id = "test2_model_drqn_mid_expert"
model = drqn.DRQNModel(model_id, timesteps=seqlen-1)
model.init_trainer()


Loaded model test2_model_drqn_mid_expert

In [ ]:
# train the model (uses the previously initialized trainer object)
date_time_string = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
run_id = "{}".format(date_time_string)
model.train(train_buffer, val_buffer, n_epoch=60,
              run_id=run_id, load_checkpoint=True)


Training Step: 128  | total loss: 2.61190 | time: 3.357s
| Optimizer | epoch: 001 | loss: 2.61190 -- iter: 08192/80000

In [47]:
test_drqn(model_id=model_id, DEBUG=True)


Testing model: test2_model_drqn_mid
horizon: 6
Loaded model test2_model_drqn_mid
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 0: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 1: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 2: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 3: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 4: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 5: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 6: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 7: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 8: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 9: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 10: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 11: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 12: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 13: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 14: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 15: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 16: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 17: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 18: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 19: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 20: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 21: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 22: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 23: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 24: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 25: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 26: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 27: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 28: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 29: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 30: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 31: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 32: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 33: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 34: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 35: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 36: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 37: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 38: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 39: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 40: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 41: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 42: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 43: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 44: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 45: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 46: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 47: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 48: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 49: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 50: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 51: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 52: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 53: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 54: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 55: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 56: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 57: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 58: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 59: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 60: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 61: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 62: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 63: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 64: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 65: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 66: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 67: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 68: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 69: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 70: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 71: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 72: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 73: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 74: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 75: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 76: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 77: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 78: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 79: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 80: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 81: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 82: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 83: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 84: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 85: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 86: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 87: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 88: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 89: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 90: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 91: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 92: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 93: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 94: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 95: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 96: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 97: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 98: acc: 0.75
ERROR [ 1.  0.  1.  0.] executed non-optimal action 2 
 with predicted q value [ 1.59116113  1.50520086  1.56009531  1.50421655] 


traj i 99: acc: 0.75
Generating data for 1000 students with behavior policy expert and sequence length 6.
Average posttest true: 1.0
Average posttest drqn: 0.75

In [12]:
model_id = "test2_model_drqn_mid"
model = drqn.DRQNModel(model_id, timesteps=seqlen-1)
model.init_trainer()
# train the model (uses the previously initialized trainer object)
date_time_string = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
run_id = "{}".format(date_time_string)
model.train(train_buffer, val_buffer, n_epoch=1,
              run_id=run_id, load_checkpoint=True)


Training Step: 2499  | total loss: 0.00137 | time: 54.716s
| Optimizer | epoch: 001 | loss: 0.00137 -- iter: 79936/80000
Training Step: 2500  | total loss: 0.00174 | time: 57.191s
| Optimizer | epoch: 001 | loss: 0.00174 | val_loss: 0.00247 -- iter: 80000/80000
--
WARNING:tensorflow:Error encountered when serializing layer_tensor/lstm_2.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'
WARNING:tensorflow:Error encountered when serializing layer_tensor/lstm_1.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'

In [31]:
a = np.array([[1,2,3], [0,5,6]])

In [33]:
np.argmax(a, axis=1)


Out[33]:
array([2, 2])

In [ ]: