In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import time

import os
import sys
import numpy as np
import tensorflow as tf
from six.moves import xrange
import deepwarp
import load_dataset2

from config import get_config
conf, _ = get_config()

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir', '/log_pred', """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/checkpoints', """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5, """How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 10000, """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False, """Whether to run eval only once.""")
conf.dataset = 'dirl_v2'  # override the config default ('None') for this run
tf.app.flags.DEFINE_string('data_dir', '../../dataset/', """Path to the dataset directory.""")
validation_portion = 0.05  # fraction of user directories reserved for validation


Namespace(agl_dim=2, batch_size=128, channel=3, dataset='None', ef_dim=14, encoded_agl_dim=16, epochs=500, eye='None', height=41, is_cfw_only=False, load_weights='None', lr=0.0001, width=51)
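
The Namespace above is what get_config() returns; the fields the rest of this notebook relies on are the patch geometry and the angle dimensionality. A quick sanity check (a minimal sketch; the expected values are read off the printed Namespace):

# Eye patches are height x width x channel = 41 x 51 x 3; anchor maps add
# ef_dim = 14 channels, and gaze angles have agl_dim = 2 components.
assert (conf.height, conf.width, conf.channel) == (41, 51, 3)
assert conf.ef_dim == 14 and conf.agl_dim == 2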

In [2]:
# def eval_once(saver, summary_writer, top_k_op, summary_op):
#     """Run Eval once.
#     Args:
#       saver: Saver.
#       summary_writer: Summary writer.
#       top_k_op: Top K op.
#       summary_op: Summary op.
#     """
#     with tf.Session() as sess:
#         ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
#         if ckpt and ckpt.model_checkpoint_path:
#             # Restores from checkpoint
#             saver.restore(sess, ckpt.model_checkpoint_path)
#             # Assuming model_checkpoint_path looks something like:
#             #   /my-favorite-path/cifar10_train/model.ckpt-0,
#             # extract global_step from it.
#             global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
#         else:
#             print('No checkpoint file found')
#             return

#         # Start the queue runners.
#         coord = tf.train.Coordinator()
#         try:
#             threads = []
#             for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
#                 threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))

#             num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
#             true_count = 0  # Counts the number of correct predictions.
#             total_sample_count = num_iter * FLAGS.batch_size
#             step = 0
#             while step < num_iter and not coord.should_stop():
#                 predictions = sess.run([top_k_op])
#                 true_count += np.sum(predictions)
#                 step += 1

#             # Compute precision @ 1.
#             precision = true_count / total_sample_count
#             print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

#             summary = tf.Summary()
#             summary.ParseFromString(sess.run(summary_op))
#             summary.value.add(tag='Precision @ 1', simple_value=precision)
#             summary_writer.add_summary(summary, global_step)
#         except Exception as e:  # pylint: disable=broad-except
#             coord.request_stop(e)

#         coord.request_stop()
#         coord.join(threads, stop_grace_period_secs=10)
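
The commented cell above is the stock CIFAR-10 evaluation loop. Stripped of the queue-runner plumbing, the quantity it computes is simply precision@1 over fixed-size batches. A minimal restatement (a sketch; precision_at_1 is a hypothetical helper, not part of this codebase):

def precision_at_1(per_batch_true_counts, batch_size):
    """Precision@1 as in eval_once: correct predictions / total sample count."""
    total_sample_count = len(per_batch_true_counts) * batch_size
    return sum(per_batch_true_counts) / total_sample_count

# e.g. three batches of 128 with 100, 110, and 90 correct predictions:
print(precision_at_1([100, 110, 90], 128))  # 0.78125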

In [3]:
def get_pair(imgs):
    """Enumerate every (user_id, source_idx, target_idx) pair within each user."""
    all_pairs = []
    for uid in xrange(len(imgs)):
        # All source/target index combinations for this user, self-pairs included.
        n_img = np.arange(len(imgs[uid]))
        sur, tar = np.meshgrid(n_img, n_img)
        uid_col = np.repeat(uid, len(n_img) * len(n_img))
        all_pairs.append(np.stack((uid_col, sur.reshape(-1), tar.reshape(-1)), axis=1))
    pairs = np.concatenate(all_pairs, axis=0)
    print(pairs.shape)
    return pairs
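
A quick check of what get_pair produces (a sketch; the dummy arrays below are hypothetical, since only the per-user image counts matter):

# Two users with 2 and 3 images -> 2*2 + 3*3 = 13 rows of (user_id, source_idx, target_idx).
dummy = [np.zeros((2, 41, 51, 3)), np.zeros((3, 41, 51, 3))]
pairs = get_pair(dummy)  # prints (13, 3)
print(pairs[:4])         # note that self-pairs (source == target) are included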

In [4]:
def data_iterator(imgs, agls, anchor_maps, pairs, batch_size):
    """Yield shuffled minibatches of (source image, anchor map, angle delta, target image)."""
    while True:
        idxs = np.arange(0, len(pairs))
        np.random.shuffle(idxs)
        for batch_idx in range(0, len(idxs), batch_size):
            cur_idxs = idxs[batch_idx:batch_idx + batch_size]
            pairs_batch = pairs[cur_idxs]
            img_batch = []   # source eye patches
            fp_batch = []    # anchor (feature-point) maps of the source patches
            agl_batch = []   # gaze-angle deltas (target - source)
            img__batch = []  # target (ground-truth) patches
            for pair_idx in range(len(pairs_batch)):
                uID, surID, tarID = pairs_batch[pair_idx]
                print(uID, surID, tarID)  # debug: every sampled (user, source, target) triple
                img_batch.append(imgs[uID][surID])
                agl_batch.append(agls[uID][tarID] - agls[uID][surID])
                fp_batch.append(anchor_maps[uID][surID])
                img__batch.append(imgs[uID][tarID])

#             print(np.asarray(img_batch).shape)
#             print(np.asarray(agl_batch).shape)
#             print(np.asarray(fp_batch).shape)
#             print(np.asarray(img__batch).shape)
#             print(pairs_batch.shape)
            yield np.asarray(img_batch), np.asarray(fp_batch), np.asarray(agl_batch), np.asarray(img__batch)
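
One draw from the iterator yields four index-aligned arrays. With the config above and a batch size of 128, the expected shapes are as follows (a sketch, assuming imgs, agls, anchor_maps, and pairs have been loaded as in evaluate() below):

it = data_iterator(imgs, agls, anchor_maps, pairs, batch_size=128)
img, fp, agl, img_gt = next(it)
# img:    (128, 41, 51, 3)   source eye patches
# fp:     (128, 41, 51, 14)  anchor / feature-point maps of the sources
# agl:    (128, 2)           (target - source) gaze-angle deltas
# img_gt: (128, 41, 51, 3)   target (ground-truth) patches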

In [5]:
def evaluate():
  """Build the input pipeline and evaluate the gaze-redirection model.

  Adapted from the CIFAR-10 evaluation template; the graph construction and
  checkpoint evaluation below are still commented out while the input
  pipeline is verified.
  """
  with tf.Graph().as_default() as g:
    # Load per-user eye patches, gaze angles, and anchor maps.
    data_dir = os.path.join(FLAGS.data_dir, conf.dataset, 'trainging_pickle/')
    dirs = np.asarray([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))])
    # training_dirs = dirs[0:(dirs.shape[0]-int(dirs.shape[0]*validation_portion))]
    # validation_dirs = dirs[(dirs.shape[0]-int(dirs.shape[0]*validation_portion)):dirs.shape[0]]
    imgs, agls, _, anchor_maps = load_dataset2.load(data_dir=data_dir, dirs=dirs, eye="L", pose="0P")

    if len(imgs) != len(agls) or len(imgs) != len(anchor_maps):
        sys.exit("Mismatched lengths between imgs, agls, and anchor_maps")

    # Pairs are built from agls, but only the per-user lengths matter:
    # imgs, agls, and anchor_maps are all index-aligned.
    pairs = get_pair(agls)
    iter_ = data_iterator(imgs, agls, anchor_maps, pairs, 128)
    tcb = time.time()
    inputs_batch = next(iter_)
    print(time.time() - tcb)   # seconds to assemble one batch
    print(inputs_batch[2])     # gaze-angle deltas for the batch (printed below)
#     # define placeholder for inputs to network
#     with tf.name_scope('inputs'):
#         input_img = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.channel], name="input_img") # [None, 41, 51, 3]
#         input_fp = tf.placeholder(tf.float32, [None, conf.height, conf.width,conf.ef_dim], name="input_fp") # [None, 41, 51, 14]
#         input_ang = tf.placeholder(tf.float32, [None, conf.agl_dim], name="input_ang") # [None, 2]
#         phase_train = tf.placeholder(tf.bool, name='phase_train') # a bool for batch_normalization
#         img_ = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.channel], name ="Ground_Truth")
#     # Build a Graph that computes the logits predictions from the
#     # inference model.
#     img_pred = deepwarp.inference(input_img, input_fp, input_ang, phase_train, conf)

#     # Calculate predictions (CIFAR-10 template: logits and labels are not defined in this notebook).
#     top_k_op = tf.nn.in_top_k(logits, labels, 1)

#     # Restore the moving average version of the learned variables for eval.
#     variable_averages = tf.train.ExponentialMovingAverage(
#         cifar10.MOVING_AVERAGE_DECAY)
#     variables_to_restore = variable_averages.variables_to_restore()
#     saver = tf.train.Saver(variables_to_restore)

#     # Build the summary operation based on the TF collection of Summaries.
#     summary_op = tf.summary.merge_all()

#     summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
    
#     # Checkpoint restore and the precision@1 loop are implemented in eval_once above.
#     while True:
#       eval_once(saver, summary_writer, top_k_op, summary_op)
#       if FLAGS.run_once:
#         break
#       time.sleep(FLAGS.eval_interval_secs)

evaluate()


(324856, 3)
18 45 64
21 58 18
1 56 28
11 32 21
2 64 42
9 5 42
12 91 97
12 8 35
8 23 14
7 17 69
12 63 93
9 59 24
22 63 14
19 56 16
17 1 13
15 94 84
13 78 15
23 58 73
28 78 74
27 52 98
30 96 10
25 91 19
6 25 43
6 68 69
32 75 12
4 34 17
6 43 37
27 11 26
6 48 91
14 89 83
7 93 65
31 36 3
4 83 9
5 86 17
27 12 4
28 20 41
5 91 9
4 90 88
15 87 39
22 57 87
17 69 79
15 60 48
24 96 31
9 8 61
28 4 8
1 78 3
32 66 23
26 92 98
29 40 40
17 54 13
8 4 56
16 40 5
15 97 78
6 8 32
0 2 39
0 12 32
5 37 12
18 63 5
31 31 12
29 69 27
16 29 26
3 18 14
16 37 58
3 42 15
27 4 7
13 65 61
27 48 55
31 35 39
11 24 28
19 65 4
1 18 17
22 19 94
11 21 67
23 32 65
14 38 65
2 47 56
27 7 72
12 33 97
26 8 71
6 68 82
17 85 11
30 93 68
26 93 33
6 51 13
19 49 45
27 87 22
29 33 48
2 13 45
14 88 24
24 61 79
25 65 20
9 89 58
2 59 64
12 43 3
19 68 96
31 65 61
28 15 61
25 0 19
25 48 42
25 34 89
3 56 35
13 62 2
22 55 6
9 40 87
28 1 74
0 40 31
16 27 85
32 14 49
2 98 27
2 4 62
2 19 53
9 51 67
24 57 26
15 38 66
27 6 45
2 50 31
8 9 32
11 41 24
24 58 31
1 65 10
1 33 82
28 48 88
16 63 36
26 94 93
28 64 16
3 43 91
20 89 62
30 48 3
0.020032405853271484
[[  5.  28.]
 [-13. -31.]
 [ 19. -40.]
 [-46.   8.]
 [  0. -51.]
 [-21.   2.]
 [ 23.   2.]
 [  8. -29.]
 [  5.   3.]
 [-10.  40.]
 [  4.   3.]
 [-28. -15.]
 [-12. -29.]
 [ 20. -30.]
 [ 10. -10.]
 [-16.  -2.]
 [  2. -47.]
 [ 30.  10.]
 [ 22.  -5.]
 [ 41.  -1.]
 [ 17. -21.]
 [ 10. -60.]
 [ 16. -19.]
 [-10.   0.]
 [-20. -40.]
 [  0.  10.]
 [ 10.  10.]
 [ 33. -10.]
 [ 30.  40.]
 [ 40. -10.]
 [-40. -20.]
 [-29.  -7.]
 [ 30. -42.]
 [  3. -50.]
 [ 40.  10.]
 [ 30. -20.]
 [-33. -44.]
 [ 20.   0.]
 [ 20. -80.]
 [-10.  20.]
 [-30.  10.]
 [-30. -10.]
 [-17. -70.]
 [  0.  28.]
 [-11.  -1.]
 [ -1. -37.]
 [  1. -38.]
 [ 40.   0.]
 [  0.   0.]
 [-30. -20.]
 [  0.  20.]
 [ 40.  30.]
 [-10.  21.]
 [-20. -19.]
 [-28. -30.]
 [ 17. -12.]
 [ -3.  12.]
 [ 27. -26.]
 [ 24.  15.]
 [-50. -50.]
 [ -6.   0.]
 [-60.   0.]
 [-30.  40.]
 [ 30. -20.]
 [-11.  -5.]
 [ 28.  -1.]
 [ -4.  10.]
 [-50. -10.]
 [-47.  -1.]
 [  5. -25.]
 [-10.   0.]
 [-30.  60.]
 [ 15.  40.]
 [ 16.  46.]
 [-16.  59.]
 [-22.  15.]
 [ 19.  37.]
 [ 13.  46.]
 [-18.  14.]
 [ 20.  10.]
 [  8.  -4.]
 [-30. -20.]
 [ 50. -70.]
 [ -2. -16.]
 [  4.  -7.]
 [ 18. -70.]
 [ 30.  40.]
 [ 16.  13.]
 [ 11. -66.]
 [ -1.  13.]
 [ 31. -36.]
 [ 30. -30.]
 [ 10.   1.]
 [  0. -10.]
 [  7. -12.]
 [  2.   9.]
 [ 30.  28.]
 [ 20. -10.]
 [ 18.  -4.]
 [-32.  43.]
 [ 33. -51.]
 [-15. -11.]
 [  0. -10.]
 [-18.  45.]
 [ 35.  34.]
 [-30.  10.]
 [-44.  58.]
 [ 21.  20.]
 [  4. -36.]
 [ 10.  20.]
 [  0.  20.]
 [ 10.  10.]
 [ 13. -39.]
 [ 20.  60.]
 [-38.  10.]
 [-10. -30.]
 [-22. -17.]
 [  9. -25.]
 [ 11. -40.]
 [ 19. -36.]
 [-50.  60.]
 [  3.  40.]
 [-10. -40.]
 [ 10.   0.]
 [  9. -39.]
 [ 30.  40.]
 [  7. -26.]
 [  9. -10.]]
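
The 128 x 2 array above holds the (target - source) gaze-angle deltas for the batch. As a sanity check (a sketch; inputs_batch and agls are local to this evaluate() run, and the indices below come from the first debug triple printed above, which changes with every shuffle):

# First printed triple was "18 45 64": user 18, source image 45, target image 64,
# so the first delta row should equal agls[18][64] - agls[18][45] = [5., 28.].
np.testing.assert_array_equal(inputs_batch[2][0], agls[18][64] - agls[18][45])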

In [6]:
# def main(argv=None):  # pylint: disable=unused-argument
#     evaluate()


# if __name__ == '__main__':
#     tf.app.run()