In [ ]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
import load_data as load
from models.model import Model
from models.customlayers import *
from models.activations import *
from training import *
import moviepy.editor as mpe
from models.AELSTM import *
L = tf.layers
import matplotlib.pyplot as plt
% matplotlib inline
# Directory of extracted frames; the leading '~' is expanded to the user's home.
data_dir = os.path.expanduser(
    '~/Insight/video-representations/frames'
)
In [ ]:
# -- Model hyper-parameters -------------------------------------------------
batchsize = 1
sequence_length = 64
model = Model(encoder, lstm_cell, tied_decoder, batchsize, sequence_length)


def _mse(prediction, reference):
    """Return a scalar op: mean of the element-wise squared difference.

    Builds exactly ``reduce_mean(pow(prediction - reference, 2))``.
    """
    residual = prediction - reference
    return tf.reduce_mean(tf.pow(residual, 2))


## LSTM-Encoder Training Graph ##
# Two un-shuffled readers over the 'training' split.
# NOTE(review): the third argument to load.inputs looks like an epoch count
# (the evaluation cell stops on OutOfRangeError) -- confirm.
# NOTE(review): compare_inputs / compare_targets are not used by any visible
# cell; presumably they still feed their own input queue -- drop if unneeded.
training_inputs, training_targets = load.inputs('training', batchsize, 1, shuffle=False)
compare_inputs, compare_targets = load.inputs('training', batchsize, 1, shuffle=False)
encoded, transitioned, decoded = model.build(training_inputs)

# Model error: decoded output vs. the training targets.
loss = _mse(decoded, training_targets)
# Diagnostic: decoded output vs. the inputs the model was fed.
fakeloss = _mse(decoded, training_inputs)
# Diagnostic baseline with no model involved: targets vs. inputs.
doublefakeloss = _mse(training_targets, training_inputs)

## LSTM-Encoder Validation Graph ##
# Built after the training graph so reuse=True can share its variables
# (presumably -- depends on Model.build honouring `reuse`).
validation_inputs, validation_targets = load.inputs('validation', batchsize, 1)
encoded_validation, transitioned_validation, decoded_validation = model.build(
    validation_inputs, reuse=True)
validation_loss = _mse(decoded_validation, validation_targets)
In [ ]:
# Only trainable variables are restored from the checkpoint; everything else
# must be initialized in the session before restore() runs.
saver = tf.train.Saver(tf.trainable_variables())
init_global = tf.global_variables_initializer()
# Local variables are initialized separately; NOTE(review): presumably these
# include the epoch counters of the load.inputs pipelines -- confirm.
init_local = tf.local_variables_initializer()
coord = tf.train.Coordinator()

# Per-batch values, kept for later inspection / plotting.
loss_cache = []      # values of `loss`           (decoded vs. training_targets)
fakeloss_cache = []  # values of `fakeloss`       (decoded vs. training_inputs)
dfloss_cache = []    # values of `doublefakeloss` (training_targets vs. training_inputs)

with tf.Session() as sesh:
    # Initialize all globals first: the Saver's var_list is trainable-only,
    # so any non-trainable global variable would otherwise stay uninitialized.
    # restore() then overwrites the trainable ones with checkpoint values.
    sesh.run(init_global)
    saver.restore(sesh, 'ptypelstm-tied-relu')
    sesh.run(init_local)
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        step = 0  # number of batches evaluated
        while not coord.should_stop():
            loss_val, fakeloss_val, dfloss_val = sesh.run(
                [loss, fakeloss, doublefakeloss])
            loss_cache.append(loss_val)
            fakeloss_cache.append(fakeloss_val)
            dfloss_cache.append(dfloss_val)
            step += 1
    except tf.errors.OutOfRangeError:
        # Input exhausted. Report the mean over every evaluated batch rather
        # than only the last one, and avoid a NameError on `loss_val` when no
        # batch was produced at all.
        # NOTE(review): these batches come from the 'training' split, not
        # 'validation', despite the message wording.
        if loss_cache:
            print('Encoder validated: {:.2f}'.format(np.mean(loss_cache)))
        else:
            print('Encoder validated: no batches were evaluated')
    finally:
        coord.request_stop()
        coord.join(threads)