from __future__ import print_function
from __future__ import division

import os
import pickle
import sys

import numpy as np
np.set_printoptions(precision=4, suppress=True)

# Freeze seed - Must be called before import of Keras!
seed = 12345
print("Seed state - {seed}".format(seed=seed))
# Actually apply the seed; printing it alone does not make runs reproducible.
np.random.seed(seed)

import cv2
from visualize import OrthoData, imshow
%matplotlib inline

from utils import get_files_with_ext

# Root of the project checkout; every path below is derived from it.
project_base_dir = '/home/moti/cg/project'

# Session output roots: generated data sets and meshNet model runs.
data_sessions_outputs = os.path.join(project_base_dir, 'sessions_outputs')
model_sessions_outputs = os.path.join(project_base_dir,
                                      'meshNet/sessions_outputs')

# Data set directory (ROI and grid step are encoded in its name).
data_dir = os.path.join(data_sessions_outputs,
                        'berlinRoi_-1600_-800_1600_1600_GridStep10')

# Directory that holds the evaluation pickles scanned further below.
results_dir = os.path.join(project_base_dir, 'meshNet', 'results')

def nearest_neighbour(y_true, y_pred, k=5):
    """Find the ``k`` nearest rows of ``y_true`` for every row of ``y_pred``.

    Returns the ``(distances, indices)`` pair produced by
    ``scipy.spatial.cKDTree.query``.
    """
    from scipy import spatial
    tree = spatial.cKDTree(y_true)
    return tree.query(y_pred, k)


# Per-sample nearest-neighbour rank filled in by the last call of
# get_sorted_array_idx(..., is_interpolation=False); None until then.
nn_dist = None


def get_sorted_array_idx(data, is_interpolation):
    """Return sample indices ordered from best to worst prediction.

    Args:
        data: evaluation dict. Always needs 'grid_step'. Interpolation mode
            also reads 'y_test_true' / 'y_test_pred'; otherwise
            'y_train_true' / 'y_train_pred' are read. Columns 0 and 1 are used
            as x / y.
        is_interpolation: selects the ranking criterion.
            True: sort the test samples by the Manhattan distance, in whole
            grid steps, between true and predicted (x, y). A histogram of
            the distances is also plotted.
            False: for every train sample, look up the k=1000 nearest
            ground-truth positions of its prediction. The sample's rank is
            the 1-based position of its own ground truth in that list, or
            k + 1 if it is absent. Ranks are stored in the global
            ``nn_dist`` and the samples are sorted by them.

    Returns:
        Index array from ``np.argsort`` (smallest distance / rank first).
    """
    global nn_dist
    grid_step = data['grid_step']
    if is_interpolation:
        y_test_true = data['y_test_true']
        y_test_pred = data['y_test_pred']

        x_manhattan_distance = abs(y_test_true[:, 0] - y_test_pred[:, 0]) // grid_step
        y_manhattan_distance = abs(y_test_true[:, 1] - y_test_pred[:, 1]) // grid_step
        xy_manhattan_distance = x_manhattan_distance + y_manhattan_distance

        from visualize import plot_hist
        # plot_hist needs at least one bin, even when every distance is 0.
        n_bins = max(int(max(xy_manhattan_distance)), 1)
        plot_hist(xy_manhattan_distance, normed=False, bins=n_bins,
                  title='Manhattan distance of ' + str(len(y_test_true)) + " test samples",
                  ylabel=None, show=True, save_path=None)

        # np.argsort([2,1,3,4,5]) is [1 0 2 3 4], i.e. from small to big.
        sorted_array_idx = np.argsort(xy_manhattan_distance)
    else:
        y_train_true = data['y_train_true']
        y_train_pred = data['y_train_pred']

        k = 1000
        _, neighbours = nearest_neighbour(y_train_true, y_train_pred, k=k)
        # Force 2-D (n_samples, k) even if k == 1 collapses the last axis.
        neighbours = np.asarray(neighbours).reshape(len(y_train_pred), -1)

        own_idx = np.arange(len(neighbours))[:, np.newaxis]
        # cKDTree pads missing neighbours with index len(y_train_true); those
        # padding entries must never count as a hit.
        hits = (neighbours == own_idx) & (neighbours < len(y_train_true))
        # 1-based position of the sample's own ground truth, k + 1 if absent.
        nn_dist = np.where(hits.any(axis=1), hits.argmax(axis=1) + 1, k + 1).astype(int)

        # np.argsort([2,1,3,4,5]) is [1 0 2 3 4], i.e. from small to big.
        sorted_array_idx = np.argsort(nn_dist)

    return sorted_array_idx

import os, errno
def mkdirs(newdir, mode=0o777):
    """Create ``newdir`` and any missing parents (like ``mkdir -p``).

    An already existing *directory* is not an error. Any other ``OSError``,
    including an existing non-directory at ``newdir``, is re-raised.
    """
    try:
        os.makedirs(newdir, mode)
    except OSError as err:
        # Reraise the error unless it's about an already existing directory
        if err.errno != errno.EEXIST or not os.path.isdir(newdir):
            raise

import pexpect

def scp_expect(src, dst, password=None, timeout=60):
    """Copy ``src`` to ``dst`` with ``scp -r``, answering a password prompt.

    Best effort: problems are printed, not raised.

    Args:
        src: scp source argument (callers pass ``"arik@" + remote_dir``).
        dst: scp destination argument.
        password: password to send if scp prompts for one. When None, it is
            read from the ``SCP_PASSWORD`` environment variable, so no
            credential has to live in the source file.
        timeout: seconds pexpect waits for each expected event.
    """
    if password is None:
        password = os.environ.get('SCP_PASSWORD')
    try:
        # Pass an argument list so that paths containing spaces survive intact.
        child = pexpect.spawn('scp', ['-r', src, dst], timeout=timeout)
        i = child.expect(["[Pp]assword:", pexpect.EOF, pexpect.TIMEOUT])

        if i == 0:  # send password
            if password is None:
                print("scp asked for a password but SCP_PASSWORD is not set")
                child.terminate(force=True)
                return
            child.sendline(password)
            # Wait for the transfer to finish.
            child.expect(pexpect.EOF)
        elif i == 1:
            print("Got the key or connection timeout")
        else:
            print("Timed out waiting for scp to start")
    except Exception as e:
        print("Oops Something went wrong buddy", e)

#     print("scp done")

def verify_images_availability(file_urls, idx_list):
    """Make sure each image selected by ``idx_list`` is present locally.

    ``file_urls[idx]`` points below the remote sessions root. That prefix is
    rewritten to ``data_sessions_outputs``. When the rewritten file is
    missing, the whole remote directory containing it is fetched with
    scp_expect() into the parent of the local directory.
    """
    remote_root = '/mnt/arik_2T_usb/project/sessions_outputs'
    for sample_idx in idx_list:
        remote_path = file_urls[sample_idx]
        local_path = remote_path.replace(remote_root, data_sessions_outputs)
        if os.path.isfile(local_path):
            continue

        print('idx:', sample_idx, local_path, "does not exist, copying from remote...")
        remote_dir = os.path.split(remote_path)[0]
        local_parent = os.sep.join((os.path.split(local_path)[0], os.pardir))
        scp_expect("arik@" + remote_dir, local_parent)
        print('idx:', sample_idx, ', done copy from remote\n')
#             print(f_local, "exists\n")

import utils

def y_inverse_transform(y, y_min_max, y_type):
    """Undo the label normalisation applied for training.

    Args:
        y: one label vector or a 2-D batch of them. Columns 0 and 1 hold the
            position normalised to [0, 1]. Columns 2+ hold the orientation.
        y_min_max: ``y_min_max[0][i]`` / ``y_min_max[1][i]`` are the target
            min / max used to rescale position column ``i``.
        y_type: when 'angle', the orientation columns are multiplied by 360
            and wrapped into [0, 360). Otherwise they are copied unchanged.

    Returns:
        Array of the same shape as ``y`` with ``squeeze()`` applied.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        y = np.expand_dims(y, axis=0)

    y_new = np.zeros_like(y)

    # Only the two position columns are min-max scaled.
    for i in range(2):
        y_new[:, i] = utils.min_max_scale(y[:, i], (0, 1), (y_min_max[0][i], y_min_max[1][i]))

    if y_type == 'angle':
        y_new[:, 2:] = (y[:, 2:] * 360) % 360
    else:
        # Other encodings (e.g. the quaternion outputs used in this script)
        # are passed through unchanged.
        y_new[:, 2:] = y[:, 2:]
    return y_new.squeeze()

from meshNet_loader import ImageProcessor, LabelsParser  #, DataLoader
from visualize import imshow
# from consts import IMAGE_RANGE
import utils
from visualize import render_view

def _xy_yaw_pitch(label):
    """Return ``(yaw_pitch, pose)`` with ``pose = float32 [x, y, yaw, pitch]``."""
    yaw_pitch = utils.get_yaw_pitch_from_quaternion_array(label[2:])[0]
    pose = np.empty((4,), dtype=np.float32)
    pose[:2] = label[:2]
    pose[2:] = yaw_pitch
    return yaw_pitch, pose


def _pose_str(pose):
    """Join the 4 pose components with '_' for use in a file name."""
    return '_'.join(str(component) for component in pose[:4])


def _render_and_save(pose, path, label, idx):
    """Render ``pose``, show it and write it to ``path``; skip if ``path`` exists."""
    if os.path.exists(path):
        print(label + " exists", path)
        return
    img = ImageProcessor.flip_imgs_colors(render_view(pose))
    imshow(label + '_' + str(idx), img)
    cv2.imwrite(path, img)
    print("Wrote " + label + ":", path)


def render_images(y_true, y_pred, file_urls, nn_distance, idx_list, data, is_interpolation, weights_filename_local):
    """Render and save ground-truth / predicted views for selected samples.

    For every ``idx`` in ``idx_list``:
      * the orientation part of ``y_true[idx]`` / ``y_pred[idx]`` is
        converted to (yaw, pitch);
      * both poses are rendered with ``render_view``;
      * the colour-flipped renders are written as PNGs into an
        ``evaluations_visualizations`` directory next to the directory
        holding ``weights_filename_local``.
    Images that already exist on disk are not rendered again.

    Args:
        y_true, y_pred: per-sample labels. Columns [:2] are copied as the
            position and [2:] go through the quaternion conversion
            (presumably x, y, quaternion — TODO confirm).
        file_urls: per-sample image paths below the remote sessions root.
        nn_distance: per-sample nearest-neighbour rank. Only read when
            ``is_interpolation`` is False; may be None otherwise.
        idx_list: indices into the arrays above to render.
        data: evaluation dict providing 'grid_step', 'roi' and 'x_type'.
        is_interpolation: selects the distance tag in file names: Manhattan
            grid distance ('_D_') or nearest-neighbour rank ('_nn_').
        weights_filename_local: local weights path; the output dir is derived
            from it.
    """
    print("render_images: Entered")

    grid_step = data['grid_step']
    roi = data['roi']
    prefix_str = "gridStep" + str(grid_step) + "_" + "_".join(str(r) for r in roi[:4]) + "_" + \
        data['x_type'] + "_" + ("geo-interpolation" if is_interpolation else "geo_matching")
    # The name matches the "evaluations_visualizations" filter of the
    # colour-inversion step of this script.
    evaluation_visualization_dir = os.path.join(os.path.dirname(weights_filename_local), os.pardir,
                                                'evaluations_visualizations')
    print('evaluation_visualization_dir:', evaluation_visualization_dir)
    if not os.path.exists(evaluation_visualization_dir):
        os.makedirs(evaluation_visualization_dir)

    remote_root = '/mnt/arik_2T_usb/project/sessions_outputs'
    for idx in idx_list:
        print("idx", idx)
        f_local = file_urls[idx].replace(remote_root, data_sessions_outputs)
        if os.path.isfile(f_local):
            img_true_from_disk = ImageProcessor.load_image(f_local, (160, 120))
        else:
            print(f_local, "does not exist")
            img_true_from_disk = None

        yp_true, y_true_cur = _xy_yaw_pitch(y_true[idx])
        print("yaw, pitch true:", yp_true)
        print("y_true_cur", y_true_cur)
        print("y_true:", y_true[idx])
        print("y_pred:", y_pred[idx])
        yp_pred, y_pred_cur = _xy_yaw_pitch(y_pred[idx])
        print("yaw, pitch pred:", yp_pred)
        print("y_pred_cur", y_pred_cur)

        if is_interpolation:
            x_manhattan_distance = abs(y_true_cur[0] - y_pred_cur[0]) // grid_step
            y_manhattan_distance = abs(y_true_cur[1] - y_pred_cur[1]) // grid_step
            dist_str = '_D_' + str(int(x_manhattan_distance + y_manhattan_distance))
        else:
            dist_str = '_nn_' + str(int(nn_distance[idx]))
        print("dist_str:", dist_str)

        name_stem = prefix_str + "_" + str(idx) + dist_str
        img_true_path = os.path.join(evaluation_visualization_dir,
                                     name_stem + '_img_true_' + _pose_str(y_true_cur) + ".png")
        img_pred_path = os.path.join(evaluation_visualization_dir,
                                     name_stem + '_img_pred_' + _pose_str(y_pred_cur) + ".png")

        if img_true_from_disk is not None:
            img_true_from_disk = ImageProcessor.flip_imgs_colors(img_true_from_disk)
            imshow('img_true_from_disk' + str(idx), img_true_from_disk)

        _render_and_save(y_true_cur, img_true_path, 'img_true', idx)
        _render_and_save(y_pred_cur, img_pred_path, 'img_pred', idx)

import re

def _read_from_y_min_max_line(log_file_path):
    """Return the text of ``log_file_path`` from the first line starting with
    ``y_min_max`` to the end of the file, or None if no such line exists."""
    with open(log_file_path) as f:
        for line in f:
            if line.startswith("y_min_max"):
                # The array repr may wrap over several lines, so keep the rest.
                return line + ''.join(f)
    return None


def _parse_y_min_max(text):
    """Parse the first two ``array([...])`` literals in ``text`` into a
    ``(2, n)`` float32 array (row 0 = minima, row 1 = maxima)."""
    compact = re.sub(r"\s+", "", text)
    bodies = re.findall(r"array\(\[([^\]]*)\]", compact)
    if len(bodies) < 2:
        raise Exception("Could not parse y_min_max arrays", compact[:200])
    rows = [[float(token) for token in body.split(",") if token] for body in bodies[:2]]
    if len(rows[0]) != len(rows[1]):
        raise Exception("y_min_max min/max rows differ in length", rows)
    return np.array(rows, dtype=np.float32)


def get_y_min_max_from_log(weights_filename_local, log_file_path=None):
    """Recover the label normalisation range ``y_min_max`` from a training log.

    The log is expected to contain a line starting with ``y_min_max`` followed
    by the repr of two numpy arrays (it may wrap over several lines). The
    array contents are parsed directly, so the result does not depend on the
    dtype shown in the repr.

    Args:
        weights_filename_local: path of the local weights file. When
            ``log_file_path`` is None, the ``logs`` directory next to the
            weights directory is scanned (in name order). The first file
            containing a ``y_min_max`` line is used.
        log_file_path: explicit log file to read instead of scanning.
            NOTE(review): the original hard-coded log file name is not known
            here — pass it explicitly if the logs dir holds several logs.

    Returns:
        ``(2, n)`` float32 array; row 0 holds the minima, row 1 the maxima.

    Raises:
        Exception: if no ``y_min_max`` line is found or it cannot be parsed.
    """
    if log_file_path is not None:
        candidates = [log_file_path]
    else:
        logs_dir = os.path.join(os.path.dirname(weights_filename_local), os.pardir, 'logs')
        candidates = [os.path.join(logs_dir, name) for name in sorted(os.listdir(logs_dir))
                      if os.path.isfile(os.path.join(logs_dir, name))]

    text = None
    for path in candidates:
        print('log_file_path:', path)
        text = _read_from_y_min_max_line(path)
        if text is not None:
            break
    if text is None:
        raise Exception("Did not find y_min_max in log file", candidates)

    y_min_max = _parse_y_min_max(text)
    print('y_min_max:', y_min_max)
    return y_min_max

from meshNet_loader import process_labels

# Debug cache of sorted sample indices keyed by local weights path.
# NOTE(review): pickle.load executes arbitrary code from the file; only keep a
# sorting_cache.pkl you created yourself.
if os.path.exists('sorting_cache.pkl'):
    print('sorting_cache.pkl exists. Loading...')
    with open('sorting_cache.pkl', 'rb') as f:
        sorting_cache = pickle.load(f)
else:
    sorting_cache = {}

def visualize_evaluation(cur_pkl):
    """Render a selection of best/worst samples from one evaluation pickle.

    Steps:
      * load ``cur_pkl``;
      * order its samples with get_sorted_array_idx() (best first);
      * make sure the chosen images exist locally;
      * render true/predicted views for them via render_images().

    Pickles whose path contains 'val_loss' are treated as interpolation
    (test-set) runs; all others use the train-set arrays.

    NOTE(review): pickle.load trusts its input; only load pickles you created.
    """
    with open(cur_pkl, 'rb') as f:
        data = pickle.load(f)

    mesh_name = data['mesh_name']
    roi = data['roi']
    x_type = data['x_type']
    y_type = data['y_type']
    weights_filename = data['weights_filename']
    # The stored weights path may not exist on this machine; look for the same
    # file name in the 'hdf5' directory next to the pickle's directory.
    weights_filename_local = os.path.join(os.path.dirname(cur_pkl), os.path.pardir, 'hdf5',
                                          os.path.split(weights_filename)[1])

    print("mesh_name:", mesh_name)
    print('roi:', roi)
    print('grid_step:', data['grid_step'])
    print("x_type:", x_type)
    print("y_type:", y_type)
    print("weights_filename:", weights_filename)
    print("weights_filename_local:", weights_filename_local)

    if os.path.isfile(weights_filename_local):
        print("Local weight available")
    else:
        print("Local weight NOT found")

    is_interpolation = 'val_loss' in cur_pkl

    use_cache = False
    # Simple cache for debug.
    # NOTE(review): a cache hit skips get_sorted_array_idx(), so the global
    # nn_dist read below for non-interpolation runs would be stale.
    if use_cache and weights_filename_local in sorting_cache:
        print("Using cache_sorted_array_idx")
        sorted_array_idx = sorting_cache[weights_filename_local]
    else:
        sorted_array_idx = get_sorted_array_idx(data, is_interpolation)
        # Update cache
        sorting_cache[weights_filename_local] = sorted_array_idx

    split = 'test' if is_interpolation else 'train'
    file_urls = np.asarray(data['file_urls_' + split])[sorted_array_idx]
    y_true = np.asarray(data['y_' + split + '_true'])[sorted_array_idx]
    y_pred = np.asarray(data['y_' + split + '_pred'])[sorted_array_idx]

    # Positions in the sorted order to render: the best 20, a few checkpoints
    # further down the ranking, and the worst 20.
    n_samples = len(file_urls)
    candidate_idx = (list(range(20)) + list(range(45, 55)) + list(range(99, 105)) +
                     list(range(995, 1005)) + list(range(4995, 5005)) +
                     list(range(49995, 50005)) + list(range(n_samples - 20, n_samples)))
    # Drop positions outside [0, n_samples) (negatives appear when there are
    # fewer than 20 samples) and duplicates where the ranges overlap.
    idx_list = sorted(set(idx for idx in candidate_idx if 0 <= idx < n_samples))
    verify_images_availability(file_urls, idx_list)

    nn_distance = None if is_interpolation else nn_dist[sorted_array_idx]
    render_images(y_true, y_pred, file_urls, nn_distance, idx_list, data, is_interpolation, weights_filename_local)

# Main

grid_step_10_path = os.path.join(results_dir, 'gridStep10')
grid_step_20_path = os.path.join(results_dir, 'gridStep20')
grid_step_40_path = os.path.join(results_dir, 'gridStep40')

# ext_list must be a real tuple: ('.pkl') is just the string '.pkl'.
pkl_list_10 = get_files_with_ext(grid_step_10_path, ext_list=('.pkl',), recursive=True, abs_path=True,
                                 sort=True, warn_empty=True)
pkl_list_20 = get_files_with_ext(grid_step_20_path, ext_list=('.pkl',), recursive=True, abs_path=True,
                                 sort=True, warn_empty=True)
pkl_list_40 = get_files_with_ext(grid_step_40_path, ext_list=('.pkl',), recursive=True, abs_path=True,
                                 sort=True, warn_empty=True)

print('pkl_list_10:', len(pkl_list_10))
print('pkl_list_20:', len(pkl_list_20))
print('pkl_list_40:', len(pkl_list_40))
all_pkl = pkl_list_10 + pkl_list_20 + pkl_list_40

# Render the evaluation visualizations of every pickle found above.
for idx, pkl in enumerate(all_pkl):
    print('Start pkl idx', idx, ", pkl:", pkl)
    visualize_evaluation(pkl)
    print('Done pkl idx', idx, '\n\n\n')
#     break

# Invert colors for "evaluations_visualizations"

grid_step_10_path = os.path.join(results_dir, 'gridStep10')
grid_step_20_path = os.path.join(results_dir, 'gridStep20')
grid_step_40_path = os.path.join(results_dir, 'gridStep40')

# ext_list must be a real tuple: ('.png') is just the string '.png'.
img_list_10 = get_files_with_ext(grid_step_10_path, ext_list=('.png',), recursive=True, abs_path=True,
                                 sort=True, warn_empty=True)
img_list_20 = get_files_with_ext(grid_step_20_path, ext_list=('.png',), recursive=True, abs_path=True,
                                 sort=True, warn_empty=True)
img_list_40 = get_files_with_ext(grid_step_40_path, ext_list=('.png',), recursive=True, abs_path=True,
                                 sort=True, warn_empty=True)

print('img_list_10:', len(img_list_10))
print('img_list_20:', len(img_list_20))
print('img_list_40:', len(img_list_40))
all_imgs = img_list_10 + img_list_20 + img_list_40

# Choose only images of "evaluations_visualizations", skipping the inverted
# copies already written to 'white/' so that re-running does not re-invert them.
all_imgs = [fname for fname in all_imgs
            if "evaluations_visualizations" in fname
            and os.path.basename(os.path.dirname(fname)) != 'white']
print('evaluation visualization imgs number:', len(all_imgs))

print("Inverting evaluation visualization images colors...")
for idx, fname in enumerate(all_imgs):
    print('idx:', idx, 'fname:', fname)
    img = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print('Could not read', fname, '- skipping')
        continue
    img = ImageProcessor.flip_imgs_colors(img)
    white_dir = os.path.join(os.path.dirname(fname), 'white')
    # cv2.imwrite does not create directories; it just returns False.
    if not os.path.isdir(white_dir):
        os.makedirs(white_dir)
    new_fname = os.path.join(white_dir, os.path.basename(fname))
    if not cv2.imwrite(new_fname, img):
        print('Failed to write', new_fname)

# NOTE(review): this cell is truncated. The body of the ``if`` below is
# missing, so the file does not parse as-is. Judging by the names, it probably
# collected the indices of pickles outside "evaluations_visualizations" —
# confirm before restoring. ``idx_list`` is not read by any visible code after
# this point.
idx_list = []
for idx, pkl in enumerate(all_pkl):
    if "evaluations_visualizations" not in pkl:
import meshNet_model
import meshNet_loader
import resnet50
import cv2

# ResNet-50 regression model: 120x160 single-channel input, a 2-output
# position head and a 4-output rotation head.
params = dict(
    image_shape=(120, 160, 1),
    xy_nb_outs=2,
    rot_nb_outs=4,
    multi_gpu=False,
)
model, model_name = resnet50.resnet50_regression_train(**params)

# NOTE(review): ``weights_filename_local`` is only assigned inside
# visualize_evaluation(); at module level it must come from earlier notebook
# state — confirm before running this as a script.
meshNet_model.load_model_weights(model, weights_filename_local)

# test_scores = model.evaluate(loader.x_test, [loader.y_test[:, :2], loader.y_test[:, 2:]],
#                              batch_size=batch_size, verbose=0)

from meshNet_loader import ImageProcessor, LabelsParser, DataLoader
from visualize import imshow
from consts import IMAGE_RANGE
import utils
from visualize import render_view

# Taken from log file. Row 0 holds the per-column minima and row 1 the maxima,
# as consumed by y_inverse_transform().
y_min_max = np.array([
    [-1600., -800., -0.7071, -0.561, -0.536, -0.4777],
    [-800., 0, 0.5, 0.7934, 0.7934, 0.7071],
])
print("y_min_max", y_min_max)

# Scratch cell: for entries of ``file_urls_test_sorted``, run the loaded
# ``model`` on the stored image, render the true and predicted poses, and
# write both renders as PNG files into the current working directory.
# NOTE(review): this cell does not run as-is. ``file_urls_test_sorted``,
# ``x_type``, ``y_type``, ``xy_manhattan_distance`` and ``sorted_array_idx``
# are not defined at module level in this file (they exist only as locals of
# the functions above). The two ``if`` statements at the top of the loop have
# also lost their bodies.
print("x_type", x_type)
for idx, f in enumerate(file_urls_test_sorted):
    if idx % 50 != 0:  # NOTE(review): body missing -- presumably ``continue``; confirm

    if idx % 100 == 0:  # NOTE(review): body missing; intent unknown
    print("idx", idx)
    # Map the remote dataset path onto the local copy.
    f_local = f.replace('/mnt/arik_2T_usb/project/sessions_outputs', data_sessions_outputs)
    if not os.path.isfile(f_local):
        print(f_local, "does not exist")
    img = ImageProcessor.load_image(f_local, (160, 120))
    img = ImageProcessor.flip_imgs_colors(img)
    # Ground-truth pose loaded by LabelsParser for this image path. Elements
    # [:2] and [3:5] are kept -- presumably (x, y) and (yaw, pitch); TODO confirm.
    pose = LabelsParser.load_pose_file(f_local)
    print("GT pose", pose)
    y_true = np.empty((4,), dtype=np.float32)
    y_true[:2] = pose[:2]
    y_true[2:] = pose[3:5]
    print("y_true", y_true)
#     imshow('aa', img)

    # Batch of one image, rescaled from IMAGE_RANGE to (0, 1) before prediction.
    x = np.expand_dims(img, axis=0).astype(np.float32)
    x = utils.min_max_scale(x, IMAGE_RANGE, (0, 1))
    res = model.predict(x, batch_size=1, verbose=1)
    # Reassemble one 6-vector from the two model heads: presumably res[0] is
    # the 2-output xy head and res[1] the 4-output rotation head (see the
    # xy_nb_outs / rot_nb_outs model parameters).
    y = np.empty((6,), dtype=np.float32)
    y[0:2] = res[0][0][0:2]
    y[2:6] = res[1][0][0:4]
    # Undo the label normalisation, then convert the orientation to (yaw, pitch).
    y_pred_q = y_inverse_transform(y, y_min_max, y_type)
    print("y_pred_q", y_pred_q)
    yp = utils.get_yaw_pitch_from_quaternion_array(y_pred_q[2:])[0]
    print("yaw, pitch ", yp)

    y_pred = np.empty((4,), dtype=np.float32)
    y_pred[:2] = y_pred_q[:2]
    y_pred[2:] = yp[:2]
    print("y_pred", y_pred)
    print("y_true", y_true)    

    img_true = render_view(y_true)
    img_true = ImageProcessor.flip_imgs_colors(img_true)
    imshow('img_true_' + str(idx), img_true)

    img_pred = render_view(y_pred)
    img_pred = ImageProcessor.flip_imgs_colors(img_pred)
    imshow('img_pred_' + str(idx), img_pred)

    y_true_str = str(y_true[0]) + '_' + str(y_true[1]) + '_' + str(y_true[2]) + '_' + str(y_true[3])
    y_pred_str = str(y_pred[0]) + '_' + str(y_pred[1]) + '_' + str(y_pred[2]) + '_' + str(y_pred[3])
    # Presumably the per-sample Manhattan distance from an earlier interactive
    # run of get_sorted_array_idx(), looked up in sorted order -- confirm.
    dist = xy_manhattan_distance[sorted_array_idx][idx]
    dist_str = '_D_' + str(int(dist))
    # Relative file names: written into the current working directory.
    cv2.imwrite(str(idx) + dist_str + '_img_true_' + y_true_str + ".png", img_true)
    cv2.imwrite(str(idx) + dist_str + '_img_pred_' + y_pred_str + ".png", img_pred)
#     break

