In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import math
import time
import os
import sys
import numpy as np
import tensorflow as tf
from six.moves import xrange
import deepwarp
import load_dataset2
from config import get_config
conf, _ = get_config()
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('eval_dir', '/log_pred', """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/checkpoints', """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5, """How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 10000, """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False, """Whether to run eval only once.""")
conf.dataset = 'dirl_v2'
tf.app.flags.DEFINE_string('data_dir', '../../dataset/', """Path to the dataset directory.""")
validation_portion = 0.05
In [2]:
# def eval_once(saver, summary_writer, top_k_op, summary_op):
#     """Run Eval once.
#     Args:
#         saver: Saver.
#         summary_writer: Summary writer.
#         top_k_op: Top K op.
#         summary_op: Summary op.
#     """
#     with tf.Session() as sess:
#         ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
#         if ckpt and ckpt.model_checkpoint_path:
#             # Restores from checkpoint
#             saver.restore(sess, ckpt.model_checkpoint_path)
#             # Assuming model_checkpoint_path looks something like:
#             #   /my-favorite-path/cifar10_train/model.ckpt-0,
#             # extract global_step from it.
#             global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
#         else:
#             print('No checkpoint file found')
#             return
#
#         # Start the queue runners.
#         coord = tf.train.Coordinator()
#         try:
#             threads = []
#             for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
#                 threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
#
#             num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
#             true_count = 0  # Counts the number of correct predictions.
#             total_sample_count = num_iter * FLAGS.batch_size
#             step = 0
#             while step < num_iter and not coord.should_stop():
#                 predictions = sess.run([top_k_op])
#                 true_count += np.sum(predictions)
#                 step += 1
#
#             # Compute precision @ 1.
#             precision = true_count / total_sample_count
#             print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
#
#             summary = tf.Summary()
#             summary.ParseFromString(sess.run(summary_op))
#             summary.value.add(tag='Precision @ 1', simple_value=precision)
#             summary_writer.add_summary(summary, global_step)
#         except Exception as e:  # pylint: disable=broad-except
#             coord.request_stop(e)
#
#         coord.request_stop()
#         coord.join(threads, stop_grace_period_secs=10)
In [3]:
def get_pair(imgs):
    """Build all (user_id, source_idx, target_idx) triples over each user's images."""
    for uid in xrange(len(imgs)):
        # Every ordered (source, target) combination of this user's image indices.
        n_img = np.arange(len(imgs[uid]))
        sur, tar = np.meshgrid(n_img, n_img)
        uid_col = np.expand_dims(np.repeat(uid, len(imgs[uid]) * len(imgs[uid])), axis=1)
        triples = np.concatenate((uid_col,
                                  np.expand_dims(np.reshape(sur, -1), axis=1),
                                  np.expand_dims(np.reshape(tar, -1), axis=1)), axis=1)
        # Stack this user's triples under those of the previous users.
        pairs = triples if uid == 0 else np.concatenate((pairs, triples), axis=0)
    print(pairs.shape)
    return pairs
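A quick sanity check (not part of the original notebook) shows what get_pair emits: one (user_id, source_idx, target_idx) row per ordered image pair of each user, so a hypothetical two-user toy input with 2 and 3 images yields 2*2 + 3*3 = 13 rows. The 41x51x3 shape matches the image dimensions noted in the placeholder comments further down.

# Hypothetical toy check for get_pair; toy_imgs is a stand-in for the loaded dataset.
toy_imgs = [np.zeros((2, 41, 51, 3)), np.zeros((3, 41, 51, 3))]
toy_pairs = get_pair(toy_imgs)            # prints (13, 3)
assert toy_pairs.shape == (13, 3)
assert list(toy_pairs[0]) == [0, 0, 0]    # first row: user 0, source 0, target 0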
In [4]:
def data_iterator(imgs, agls, anchor_maps, pairs, batch_size):
    """Yield shuffled batches of (source image, anchor map, angle delta, target image)."""
    while True:  # reshuffle and start a new epoch each time the pairs are exhausted
        idxs = np.arange(0, len(pairs))
        np.random.shuffle(idxs)
        for batch_idx in range(0, len(idxs), batch_size):
            cur_idxs = idxs[batch_idx:batch_idx + batch_size]
            pairs_batch = pairs[cur_idxs]
            img_batch = []   # source images
            fp_batch = []    # anchor (feature-point) maps of the source images
            agl_batch = []   # gaze-angle differences (target - source)
            img__batch = []  # target (ground-truth) images
            for pair_idx in range(len(pairs_batch)):
                uID, surID, tarID = pairs_batch[pair_idx]
                img_batch.append(imgs[uID][surID])
                agl_batch.append(agls[uID][tarID] - agls[uID][surID])
                fp_batch.append(anchor_maps[uID][surID])
                img__batch.append(imgs[uID][tarID])
            yield np.asarray(img_batch), np.asarray(fp_batch), np.asarray(agl_batch), np.asarray(img__batch)
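Similarly, a hypothetical smoke test (random arrays with the shapes noted in the commented placeholders below: 41x51 RGB images, 14-channel anchor maps, 2-D gaze angles) shows what one yielded batch looks like:

# Hypothetical smoke test for data_iterator; all arrays are random stand-ins.
toy_imgs = [np.random.rand(3, 41, 51, 3).astype(np.float32)]
toy_agls = [np.random.rand(3, 2).astype(np.float32)]
toy_maps = [np.random.rand(3, 41, 51, 14).astype(np.float32)]
it = data_iterator(toy_imgs, toy_agls, toy_maps, get_pair(toy_imgs), batch_size=4)
img_b, fp_b, agl_b, img__b = next(it)
print(img_b.shape, fp_b.shape, agl_b.shape, img__b.shape)
# -> (4, 41, 51, 3) (4, 41, 51, 14) (4, 2) (4, 41, 51, 3)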
In [5]:
def evaluate():
    """Load the gaze dataset, build source/target pairs, and fetch a sample batch.

    The checkpoint restore and evaluation loop below is kept commented out;
    it still follows the CIFAR-10 eval template it was adapted from.
    """
    with tf.Graph().as_default() as g:
        # Load images, gaze angles, and anchor maps for the left eye at head pose 0P.
        data_dir = os.path.join(FLAGS.data_dir, conf.dataset, 'trainging_pickle/')
        dirs = np.asarray([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))])
        # training_dirs = dirs[0:(dirs.shape[0]-int(dirs.shape[0]*validation_portion))]
        # validation_dirs = dirs[(dirs.shape[0]-int(dirs.shape[0]*validation_portion)):dirs.shape[0]]
        imgs, agls, _, anchor_maps = load_dataset2.load(data_dir=data_dir, dirs=dirs, eye="L", pose="0P")
        if len(imgs) != len(agls) or len(imgs) != len(anchor_maps):
            sys.exit("Mismatched lengths among imgs, agls, and anchor_maps")
        pairs = get_pair(agls)  # per-user lengths match imgs, so the angle lists work here too
        iter_ = data_iterator(imgs, agls, anchor_maps, pairs, 128)
        tcb = time.time()
        inputs_batch = next(iter_)
        print(time.time() - tcb)  # time to assemble one batch
        print(inputs_batch[1])    # anchor-map batch
        # # define placeholder for inputs to network
        # with tf.name_scope('inputs'):
        #     input_img = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.channel], name="input_img")  # [None, 41, 51, 3]
        #     input_fp = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.ef_dim], name="input_fp")  # [None, 41, 51, 14]
        #     input_ang = tf.placeholder(tf.float32, [None, conf.agl_dim], name="input_ang")  # [None, 2]
        #     phase_train = tf.placeholder(tf.bool, name='phase_train')  # a bool for batch_normalization
        #     img_ = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.channel], name="Ground_Truth")
        #
        # # Build a Graph that computes the predicted image from the inference model.
        # img_pred = deepwarp.inference(input_img, input_fp, input_ang, phase_train, conf)
        #
        # # Calculate predictions.
        # top_k_op = tf.nn.in_top_k(logits, labels, 1)  # template residue: logits/labels are undefined here
        #
        # # Restore the moving average version of the learned variables for eval.
        # variable_averages = tf.train.ExponentialMovingAverage(cifar10.MOVING_AVERAGE_DECAY)
        # variables_to_restore = variable_averages.variables_to_restore()
        # saver = tf.train.Saver(variables_to_restore)
        #
        # # Build the summary operation based on the TF collection of Summaries.
        # summary_op = tf.summary.merge_all()
        # summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
        # while True:
        #     eval_once(saver, summary_writer, top_k_op, summary_op)
        #     if FLAGS.run_once:
        #         break
        #     time.sleep(FLAGS.eval_interval_secs)
evaluate()
In [6]:
# def main(argv=None):  # pylint: disable=unused-argument
#     evaluate()
#
# if __name__ == '__main__':
#     tf.app.run()
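For reference, a minimal sketch of how the commented-out scaffold above could be wired up for this image-prediction model. It assumes deepwarp.inference returns the predicted eye image and that the checkpoint restores with a plain tf.train.Saver; the template's EMA restore and precision@1, which fit a classifier rather than this model, are swapped for a mean per-pixel L2 error. Treat it as an assumption-laden outline, not the author's eval loop.

# Sketch only: assumes img_pred is the predicted image and a plain Saver
# matches the checkpoint; mean L2 error replaces the classifier metrics.
def evaluate_sketch(imgs, agls, anchor_maps, pairs):
    with tf.Graph().as_default():
        input_img = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.channel])
        input_fp = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.ef_dim])
        input_ang = tf.placeholder(tf.float32, [None, conf.agl_dim])
        phase_train = tf.placeholder(tf.bool)
        img_ = tf.placeholder(tf.float32, [None, conf.height, conf.width, conf.channel])
        img_pred = deepwarp.inference(input_img, input_fp, input_ang, phase_train, conf)
        l2_err = tf.reduce_mean(tf.square(img_pred - img_))  # simple image-prediction metric
        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return
            saver.restore(sess, ckpt.model_checkpoint_path)
            img_b, fp_b, agl_b, img__b = next(data_iterator(imgs, agls, anchor_maps, pairs, 128))
            err = sess.run(l2_err, feed_dict={input_img: img_b, input_fp: fp_b,
                                              input_ang: agl_b, img_: img__b,
                                              phase_train: False})
            print('%s: mean L2 error = %.5f' % (datetime.now(), err))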