from shutil import copyfile
import subprocess
from subprocess import Popen, PIPE
from multiprocessing import Pool, freeze_support, cpu_count
import matplotlib.pyplot as plt
import itertools
import os
from glob import glob
import numpy as np
import argparse
import sys

from itertools import islice
import cv2
from copy import copy, deepcopy
from scipy.ndimage import rotate

import logging
logger = logging.getLogger()
import config
import math

import mosaic as m

from skimage.io import imread, imsave
from skimage.color import gray2rgb, rgb2gray
from skimage.feature import match_descriptors
from skimage.feature import corner_harris, corner_peaks
from skimage.feature.util import _mask_border_keypoints, DescriptorExtractor
from skimage.measure import ransac
from skimage.transform import warp, SimilarityTransform, AffineTransform, ProjectiveTransform
from skimage import img_as_float, img_as_ubyte
# Module-level feature detectors built with the legacy (OpenCV 2.4-style)
# constructors; all four are created in the same left-to-right order.
sift, orb, surf, brisk = cv2.SIFT(), cv2.ORB(), cv2.SURF(), cv2.BRISK()

# Laplacian-pyramid blending demo: join the left half of apple.jpg with the
# right half of orange.jpg, both smoothly (pyramid) and directly (hstack).
A = cv2.imread('apple.jpg')
B = cv2.imread('orange.jpg')
# cv2.imread returns None instead of raising when a file cannot be read
if A is None or B is None:
    raise IOError("Could not read 'apple.jpg' and/or 'orange.jpg'")

# generate Gaussian pyramid for A (original image plus 6 pyrDown levels)
G = A.copy()
gpA = [G]
for i in range(6):
    G = cv2.pyrDown(G)
    gpA.append(G)

# generate Gaussian pyramid for B
G = B.copy()
gpB = [G]
for i in range(6):
    G = cv2.pyrDown(G)
    gpB.append(G)

# generate Laplacian Pyramid for A (coarsest level first)
lpA = [gpA[5]]
for i in range(5, 0, -1):
    # pin the upsampled size to the finer level so odd dimensions line up
    GE = cv2.pyrUp(gpA[i], dstsize=(gpA[i - 1].shape[1], gpA[i - 1].shape[0]))
    L = cv2.subtract(gpA[i - 1], GE)
    lpA.append(L)

# generate Laplacian Pyramid for B
lpB = [gpB[5]]
for i in range(5, 0, -1):
    GE = cv2.pyrUp(gpB[i], dstsize=(gpB[i - 1].shape[1], gpB[i - 1].shape[0]))
    L = cv2.subtract(gpB[i - 1], GE)
    lpB.append(L)

# Now add left and right halves of images in each level
LS = []
for la, lb in zip(lpA, lpB):
    rows, cols, dpt = la.shape
    half = cols // 2  # integer split point (plain / is a float on Python 3)
    ls = np.hstack((la[:, 0:half], lb[:, half:]))
    LS.append(ls)

# now reconstruct, upsampling each partial result to the next level's size
ls_ = LS[0]
for i in range(1, 6):
    ls_ = cv2.pyrUp(ls_, dstsize=(LS[i].shape[1], LS[i].shape[0]))
    ls_ = cv2.add(ls_, LS[i])

# image with direct connecting each half
half = A.shape[1] // 2
real = np.hstack((A[:, :half], B[:, half:]))

cv2.imwrite('Pyramid_blending2.jpg', ls_)
cv2.imwrite('Direct_blending.jpg', real)

def add_alpha(img, mask=None):
    """Add an alpha channel to an image.

    Parameters
    ----------
    img : (M, N), (M, N, 1) or (M, N, C) ndarray
        Image data. A 2-D or single-channel image is first expanded to three
        identical channels. If ``img`` already has more than three channels it
        is assumed to carry alpha and is returned unchanged.
    mask : (M, N) ndarray, optional
        Values to stack as the alpha channel. If None, a fully opaque alpha
        channel is created: the dtype's maximum for integer images (255 for
        uint8) and 1 for floating-point images.

    Returns
    -------
    ndarray
        ``img`` with the alpha channel stacked as its last channel.

    Raises
    ------
    ValueError
        If ``img`` is not 2-D or 3-D (e.g. ``None`` from a failed imread).
    """
    img = np.asarray(img)
    if img.ndim not in (2, 3):
        raise ValueError("Expected a 2-D or 3-D image, got shape %s" % (img.shape,))
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim == 2:
        # make sure the image is 3 channels
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[2] > 3:
        # don't do anything if there is already an alpha channel
        return img
    if mask is None:
        # alpha at its maximum means fully opaque
        if np.issubdtype(img.dtype, np.integer):
            opaque = np.iinfo(img.dtype).max
        else:
            opaque = 1
        mask = np.full(img.shape[:2], opaque, dtype=img.dtype)
    return np.dstack((img, mask))

def find_corners(all_corners):
    """Return the translation and canvas size that enclose all given corners.

    Parameters
    ----------
    all_corners : (K, 2) array_like
        Corner coordinates in (x, y) order.

    Returns
    -------
    offset : SimilarityTransform
        Pure translation by the negated minimum corner, moving every corner
        into non-negative coordinates.
    output_shape : ndarray of int
        ``ceil(max - min)`` with the axes reversed to (rows, cols).
    """
    lowest = np.min(all_corners, axis=0)
    highest = np.max(all_corners, axis=0)
    # (x, y) extent -> (rows, cols), rounded up to whole pixels
    extent = np.ceil((highest - lowest)[::-1]).astype(int)
    shift = SimilarityTransform(translation=-lowest)
    return shift, extent

def _create_detector(name):
    """Instantiate the OpenCV feature detector called ``name`` (SIFT by default)."""
    known = ("SIFT", "SURF", "ORB", "BRISK", "AKAZE", "KAZE")
    name = str(name).upper() if name else "SIFT"
    if name not in known:
        name = "SIFT"
    # newer OpenCV exposes factories like cv2.SIFT_create; 2.4 used cv2.SIFT()
    factory = getattr(cv2, name + "_create", None) or getattr(cv2, name)
    return factory()


def getKeypointandDescriptors(img, detector):
    """Detect keypoints in ``img`` and compute their descriptors.

    Parameters
    ----------
    img : ndarray
        Image handed to the detector (callers in this file pass grayscale).
    detector : object, str or None
        An object with a ``detectAndCompute`` method (e.g. an OpenCV
        detector), or the name of one ("SIFT", "ORB", ...). None or an
        unknown name falls back to SIFT.

    Returns
    -------
    kp : (K, 2) float ndarray
        Keypoint locations taken from each keypoint's ``pt`` attribute.
    des : ndarray or None
        Descriptors exactly as returned by ``detectAndCompute``.
    """
    if not hasattr(detector, "detectAndCompute"):
        detector = _create_detector(detector)
    kps, des = detector.detectAndCompute(img, None)
    # reshape keeps a consistent (0, 2) shape when nothing is detected
    kp = np.asarray([k.pt for k in kps], dtype=float).reshape(-1, 2)
    return kp, des

def loadImage(img_path, detector):
    """Read an image, add an alpha channel, and extract its features.

    Parameters
    ----------
    img_path : str
        Path of the image to read with ``cv2.imread`` (BGR channel order).
    detector : object or str
        Passed through to ``getKeypointandDescriptors``.

    Returns
    -------
    rgb : ndarray
        The image with an alpha channel appended (channels are B, G, R, A
        despite the name).
    k, d
        Keypoints and descriptors computed on the grayscale version.

    Raises
    ------
    IOError
        If the image cannot be read (``cv2.imread`` returns None).
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise IOError("Unable to read image: %s" % img_path)
    rgb = add_alpha(bgr)
    img = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
    # Find key points in base image
    k, d = getKeypointandDescriptors(img, detector)
    return rgb, k, d

def make_chunks(it, size):
    """Split the sliceable sequence ``it`` into consecutive pieces of ``size``.

    The last piece may be shorter; an empty input yields an empty list.
    """
    pieces = []
    for start in range(0, len(it), size):
        pieces.append(it[start:start + size])
    return pieces

def filter_matches(matches, ratio=0.75):
    """Keep the best match of every pair that passes Lowe's ratio test.

    Parameters
    ----------
    matches : iterable of sequences
        k-nearest-neighbour candidate lists (e.g. from ``knnMatch(..., k=2)``);
        each candidate needs a ``distance`` attribute.
    ratio : float
        A best match is kept only if its distance is below ``ratio`` times the
        distance of the second-best candidate.

    Returns
    -------
    list
        The accepted best matches. Candidate lists that do not hold exactly
        two entries are skipped.
    """
    return [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < pair[1].distance * ratio]

def match_from_to(fk, fd, tk, td, min_matches, residual_threshold=2, max_trials=1000):
    """Estimate an affine transform mapping one image's keypoints onto another's.

    Parameters
    ----------
    fk, fd : ndarray
        Keypoint coordinates and descriptors of the "from" image.
    tk, td : ndarray
        Keypoint coordinates and descriptors of the "to" (base) image.
    min_matches : int
        More than this many ratio-test matches are required to attempt RANSAC.
    residual_threshold, max_trials : optional
        Passed to ``skimage.measure.ransac``. NOTE(review): the original call
        was truncated in this file; these defaults are a reconstruction.

    Returns
    -------
    (model, ransac_matches, precision)
        ``model`` maps "from" coordinates to "to" coordinates;
        ``ransac_matches`` holds the inlier ``[trainIdx, queryIdx]`` pairs;
        ``precision`` is the inlier fraction. ``(None, None, 0)`` on failure.
    """
    FLANN_INDEX_KDTREE = 1  # cv2 does not expose the FLANN enums by name
    flann_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    matcher = cv2.FlannBasedMatcher(flann_params, {})
    print("STARTING MATCH")
    # knnMatch(k=2) needs descriptors on both sides
    if fd is None or td is None or len(fd) == 0 or len(td) < 2:
        return None, None, 0

    # get matching keypoints between images (from to) or (previous, base) or (next, base)
    matches = matcher.knnMatch(fd, td, k=2)
    good = filter_matches(matches)
    if len(good) <= min_matches:
        return None, None, 0

    matches_subset = np.array([[mt.trainIdx, mt.queryIdx] for mt in good])
    src = np.asarray(fk[matches_subset[:, 1]])
    # target image is base image
    dst = np.asarray(tk[matches_subset[:, 0]])

    # TODO - select which transform to use based on sensor data?
    model_robust, inliers = ransac((src, dst), AffineTransform, min_samples=8,
                                   residual_threshold=residual_threshold,
                                   max_trials=max_trials)
    if model_robust is None or inliers is None:
        return None, None, 0
    # inliers is a boolean mask over src, so count the True entries
    precision = np.count_nonzero(inliers) / float(src.shape[0])
    ransac_matches = matches_subset[inliers]
    return model_robust, ransac_matches, precision

def warp_img(img, transform, output_shape):
    """Warp ``img`` with ``transform`` onto a canvas of ``output_shape``.

    Uses bilinear interpolation (order=1), fills outside pixels with 0 and
    clips the result. Returns None (after logging the error) if warping fails.
    """
    try:
        return warp(img, transform, order=1, mode='constant',
                    output_shape=output_shape, clip=True, cval=0)
    except Exception as e:
        # getattr: img may not be an array, and the handler must not raise
        logging.error("Error warping image %s img shape %s, output shape %s"
                      % (e, getattr(img, 'shape', None), output_shape))
        return None

def copy_new_files(input_dir, output_dir, in_ftype, out_ftype, wsize, do_clear, limit):
    """Resize input images into ``output_dir`` using ImageMagick ``convert``.

    Parameters
    ----------
    input_dir : str
        Directory searched for files whose names end in ``in_ftype``.
    output_dir : str
        Destination directory; created if missing.
    in_ftype, out_ftype : str
        Input name suffix used for globbing; output extension (without dot).
    wsize : (int, int)
        Target size, passed to ``convert -resize`` as ``WxH``.
    do_clear : bool
        If True, first delete earlier mosaic outputs (``*RUN*MATCH*<out_ftype>``).
    limit : int or None
        If given, only the first ``limit`` input files (sorted) are used.

    Files whose output already exists are skipped.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if do_clear:
        to_clear_mosaics = sorted(glob(os.path.join(output_dir, '*RUN*MATCH*%s' % out_ftype)))
        if to_clear_mosaics:
            logging.warning("Clearing RUN files from output_dir: %s" % output_dir)
            for f in to_clear_mosaics:
                os.remove(f)

    logging.info("Using convert to transfer and scan input images")
    in_files = sorted(glob(os.path.join(input_dir, '*%s' % in_ftype)))
    if limit is not None:
        in_files = in_files[:limit]
    for iimg in in_files:
        # splitext keeps dotted stems intact ("a.b.jpg" -> "a.b")
        oname = os.path.splitext(os.path.basename(iimg))[0] + '.%s' % out_ftype
        ofile = os.path.join(output_dir, oname)
        if os.path.exists(ofile):
            logging.debug("The file %s already exists" % ofile)
            continue
        cmd = ["convert", iimg, "-resize", "%dx%d" % (wsize[0], wsize[1]), ofile]
        logging.info("Calling %s" % ' '.join(cmd))
        rc = subprocess.call(cmd)
        if rc != 0:
            logging.error("convert exited with code %s for %s" % (rc, iimg))
class doMosaic():
    """Iteratively stitch the images in ``outpath`` into mosaics.

    The constructor resizes the inputs into ``outpath`` (``copy_new_files``).
    ``run_round`` then repeatedly groups the remaining images into chunks.
    Each chunk is registered against a base image, and successful matches are
    blended into ``RUN<round>_MATCH<chunk>.<ext>``. The consumed inputs are
    moved into a ``matched`` sub-directory.

    NOTE(review): many statements of this class were garbled/lost in the
    source; log-call prefixes, ``else`` branches, list appends, the enblend
    inputs and the ``mv`` calls are reconstructions to check against history.
    """

    def __init__(self, inputpath, outpath, input_image_type, do_clear=False, limit=None, wsize=(3000, 2000), addedge=False):
        self.detector = "SIFT"  # cv2.SIFT(4000); passed by name to mosaic.loadImage
        self.outpath = outpath
        self.out_ftype = 'png'
        # Parameters for nearest-neighbor matching
        self.total_matched = 0
        # images per stitching chunk; grows whenever a round makes no progress
        self.chunk_size = 3
        self.addedge = addedge
        # set once a brute-force pass (every image as base) has run
        self.brute_searched = False
        copy_new_files(inputpath, outpath, input_image_type, self.out_ftype, wsize, do_clear, limit)
        self.total_to_match = len(self._list_images())

    def _list_images(self):
        """Sorted paths of the working images currently in ``self.outpath``."""
        return sorted(glob(os.path.join(self.outpath, '*%s' % self.out_ftype)))

    @staticmethod
    def _image_corners(rows, cols):
        """Corners of a rows x cols image in (x, y) order."""
        return np.array([[0, 0], [0, rows], [cols, 0], [cols, rows]])

    def run_round(self, ROUND_NUM, last_num_imgs=1e6):
        """Run one stitching round, then recurse into the next.

        Parameters
        ----------
        ROUND_NUM : int
            Round counter used in logs and output names.
        last_num_imgs : int or float
            Image count after the previous round. If this round starts with at
            least as many images, no progress was made. The chunk size then
            grows, and a brute-force pass follows once chunks cover every
            image. Stops after a brute-force round makes no progress.
        """
        img_paths = self._list_images()
        num_imgs = len(img_paths)
        self.total_matched = self.total_to_match - num_imgs
        print("TOTAL MATCHED", self.total_matched)
        # TODO: if num_imgs doesn't shrink, need to add full search
        logging.info("Starting new round: %s num_imgs: %d last_num_imgs: %s" % (ROUND_NUM, num_imgs, last_num_imgs))
        # make sure that we have made some progress
        if num_imgs >= last_num_imgs:
            self.chunk_size += 1
            logging.debug("Didn't find any matches last run, increasing search space to: %s" % self.chunk_size)
        logging.info("FOUND %s images to stitch in round %s with chunk size: %s" % (num_imgs, ROUND_NUM, self.chunk_size))

        if self.chunk_size >= num_imgs:
            # time to bring out the big guns: try every image as the base
            # TODO: put a cap on the number of images so this doesnt blow up
            logging.info("Chunk size is larger than number of images - chunk size: %s num_imgs: %s" % (self.chunk_size, num_imgs))
            logging.info("brute_forced:%s, num_imgs: %s, last_num_imgs: %s" % (self.brute_searched, num_imgs, last_num_imgs))
            if self.brute_searched and (num_imgs >= last_num_imgs):
                logger.error("Brute forced true and did not find any matches last run, exiting")
                return  # without this the recursion below never terminates
            logging.info("Entering brute force search")
            base_index = 0
            # re-read the file list after every stitch; never iterate a list
            # that stitch_chunk / file moves have changed underneath us
            while base_index < len(img_paths):
                bn = img_paths[base_index]
                logging.info("Searching bn:%s against all other images with base index of %d" % (os.path.basename(bn), base_index))
                # base_index as CHUNK_NUM keeps output names unique in this round
                self.stitch_chunk(img_paths, base_index, base_index, ROUND_NUM)
                img_paths = self._list_images()
                logging.info("Brute force search against %s found %s matches" % (os.path.basename(bn), num_imgs - len(img_paths)))
                num_imgs = len(img_paths)
                base_index += 1
            self.brute_searched = True
            self.run_round(ROUND_NUM + 1, num_imgs)
        else:
            # divide into chunks of chunk_size to match together
            chunks = make_chunks(img_paths, self.chunk_size)
            # make sure we actually found files
            if len(chunks) > 1:
                for CHUNK_NUM, chunk in enumerate(chunks):
                    base_index = len(chunk) // 2  # middle image as base
                    self.stitch_chunk(chunk, base_index, CHUNK_NUM, ROUND_NUM)
                self.run_round(ROUND_NUM + 1, num_imgs)
            elif len(chunks) == 1:
                # NOTE(review): unreachable here (chunk_size < num_imgs
                # implies >= 2 chunks); kept as a defensive fallback
                logging.info("ONLY ONE CHUNK left")
                if len(chunks[0]) == 1:
                    logging.info("Working on last match")
                    self.stitch_chunk(chunks[0], 0, 0, ROUND_NUM)
            else:
                logging.error("DID not find any files")

    def stitch_chunk(self, chunk, base_index, CHUNK_NUM, ROUND_NUM):
        """Register the images of ``chunk`` to a base image and blend the matches.

        Parameters
        ----------
        chunk : list of str
            Image paths; copied, so the caller's list is left untouched.
        base_index : int
            Index of the base (registration target) in ``chunk``; clamped to
            the last element when out of range.
        CHUNK_NUM, ROUND_NUM : int
            Used in logs and the output name
            ``RUN<ROUND_NUM>_MATCH<CHUNK_NUM>.<out_ftype>``.

        On success the mosaic is written to ``self.outpath``. The base and
        every blended image are then moved into a ``matched`` sub-directory.
        """
        # TODO - still off by one eggh
        min_precision = 0.7
        min_matches = 11
        logging.debug("WORKING ON CHUNK num: %s of %s ROUND NUM: %s" % (CHUNK_NUM, ' '.join([os.path.basename(c) for c in chunk]), ROUND_NUM))
        if len(chunk) < 2:
            return
        chunk = list(chunk)
        if not (base_index < len(chunk)):
            logging.error("Was given incompatible base_index of %d with chunk size of %d" % (base_index, len(chunk)))
            base_index = len(chunk) - 1
        # load the center or right image to use as base
        bn = chunk.pop(base_index)
        brgb, bk, bd = m.loadImage(bn, self.detector, self.addedge)
        if bk.shape[0] <= 8:
            return

        # Shape of base image, our registration target. Transformations take
        # coordinates in (x, y) format, not (row, column), in order to be
        # consistent with most literature.
        r, c = brgb.shape[:2]
        corners = self._image_corners(r, c)
        models = []
        match_names = []
        iis = []
        for name in chunk:
            rgb, k, d = m.loadImage(name, self.detector, self.addedge)
            model_robust, _, _, _, precision = m.match_from_to(k, d, bk, bd, min_matches)
            logging.info("match precision: %s  for base %s with %s" % (precision, os.path.basename(bn), os.path.basename(name)))
            if precision > min_precision:
                # the model maps this image's coordinates into the base's, so
                # transform THIS image's corners (sizes differ once mosaics grow)
                ir, ic = rgb.shape[:2]
                corners = np.vstack((corners, model_robust(self._image_corners(ir, ic))))
                models.append(model_robust)
                match_names.append(name)
                iis.append(img_as_float(rgb))
            else:
                logging.info("Not able to match with Base:%s %d, Img: %s %d keypoints" % (os.path.basename(bn), bk.shape[0], os.path.basename(name), k.shape[0]))
        if not models:
            return

        offset, output_shape = m.find_corners(corners)
        brgb_warped = m.warp_img(img_as_float(brgb), offset.inverse, output_shape)
        if brgb_warped is None:
            logger.error("Unable to warp base img %s" % os.path.basename(bn))
            return
        oname = os.path.join(self.outpath, 'RUN%03d_MATCH%03d.%s' % (ROUND_NUM, CHUNK_NUM, self.out_ftype))
        tnames = []
        blended_names = []
        ubrgb_warped = img_as_ubyte(brgb_warped)
        # output alpha: base alpha (assumes 4-channel images from loadImage),
        # kept uint8 so cv2.imwrite can save the result
        omask = ubrgb_warped[:, :, 3].copy()
        print("Original mask", np.max(ubrgb_warped), np.min(ubrgb_warped))
        for xxv, (model, i, n) in enumerate(zip(models, iis, match_names)):
            # Translate matched image into place on the shared canvas
            tname = '/tmp/timg_%02d.png' % xxv
            logging.debug("writing tmp %s to match with bn %s as %s" % (os.path.basename(n), os.path.basename(bn), tname))
            transform = (model + offset).inverse
            rgb_warped = m.warp_img(i, transform, output_shape)
            if rgb_warped is None:
                logger.error("Base warp: Unable to warp img %s" % os.path.basename(n))
                continue
            urgb_warped = img_as_ubyte(rgb_warped)
            print("MAX", np.max(omask), np.max(urgb_warped))
            omask[urgb_warped[:, :, 3] > 0] = 255
            cv2.imwrite(tname, urgb_warped)
            tnames.append(tname)
            blended_names.append(n)
        if not tnames:
            logging.info("Not able to match %s images to %s" % (len(chunk), os.path.basename(bn)))
            return

        bname = '/tmp/bimg.png'
        cv2.imwrite(bname, ubrgb_warped)
        toname = '/tmp/tbimg.png'
        if os.path.exists(toname):
            os.remove(toname)  # never pick up a stale result from a previous call
        cmd = ['enblend', '-o', toname, bname] + tnames
        logging.info("Calling subprocess command: %s" % ' '.join(cmd))
        rc = subprocess.call(cmd)
        oimg = cv2.imread(toname)
        if rc != 0 or oimg is None:
            logger.error("enblend failed (exit code %s); not writing %s" % (rc, os.path.basename(oname)))
            return
        print("READING OUTPUT", oimg.shape)
        oout = m.add_alpha(oimg, omask)
        cv2.imwrite(oname, oout)
        print("ADDED ALPHA", oname, oout.shape)
        logging.info("Wrote %s matches to file: %s" % (len(tnames), os.path.basename(oname)))

        # move consumed inputs out of outpath so later rounds only see the mosaic
        # NOTE(review): moving the base too is a reconstruction — it is now
        # part of the mosaic and would otherwise be stitched again
        mdir = os.path.join(os.path.split(match_names[0])[0], 'matched')
        if not os.path.exists(mdir):
            os.makedirs(mdir)
        for f in blended_names + [bn]:
            oo = os.path.join(mdir, os.path.split(f)[1])
            subprocess.call(['mv', f, oo])

# Pick the test configuration to run. Both blocks below may apply; because the
# face_test block runs second, its settings win when both flags are True.
face_test = False
ice_test = True
# NOTE(review): this flag is not forwarded — doMosaic below is called with
# addedge=False. Confirm which value is intended.
addedge = True
#bpath = "/Volumes/johannah_external 1/thesis-work/201511_sea_state_DRI_Sikululiaq/uas_data/seastate_october_20/n2/image/"
#inpath = os.path.join(bpath, "flight_2")
#outpath = os.path.join(bpath, "flight_2_out")
do_clear = True
#lsize = (400, 200)
if ice_test:
    bpath = "/Users/jhansen/Desktop/"
    inpath, outpath = os.path.join(bpath, "test_in"), os.path.join(bpath, "test_out")
    do_clear, lsize = True, (600, 400)
if face_test:
    inpath, outpath = 'jo_patch/', 'aout'
    do_clear, lsize = True, (400, 200)
# NOTE(review): with neither flag set, inpath/outpath/lsize are undefined here.
dm = doMosaic(inpath, outpath, 'jpg',
              do_clear=do_clear, limit=14, wsize=lsize, addedge=False)

# TODO: name images so that their temporal order survives renaming/merging.
# TODO: filter by scale based on altitude.
# TODO: keep track of which images are in the base mosaic.
# TODO: store keypoints/descriptors of images that were already searched.
# TODO: warp the keypoints/descriptors too.

def compare(*images, **kwargs):
    """Utility function to display images side by side.

    Parameters
    ----------
    image0, image1, image2, ... : ndarray
        Images to display, one per subplot.
    labels : list, optional (keyword)
        Titles for the different images; defaults to empty titles.
    **kwargs
        Any other keyword arguments are forwarded to ``plt.subplots``.

    Returns
    -------
    (Figure, ndarray of Axes)
    """
    # pop labels BEFORE forwarding kwargs, plt.subplots does not accept it
    labels = kwargs.pop('labels', None)
    f, axes = plt.subplots(1, len(images), **kwargs)
    axes = np.array(axes, ndmin=1)
    if labels is None:
        labels = [''] * len(images)
    for n, (image, label) in enumerate(zip(images, labels)):
        axes[n].imshow(image, interpolation='nearest', cmap='gray')
        axes[n].set_title(label)
    return f, axes

# Scratch: paint a border (value 100 in channel 0) on a copy of a test image.
a = imread('dji_0029s.jpg')
aa = a.copy()

ab = cv2.imread('/Users/jhansen/Desktop/test_out/RUN001_MATCH000.png', cv2.IMREAD_UNCHANGED)

# same width on every side (the bottom/right edges were previously 4 px
# against 5 px at the top/left)
border = 5
aa[:border, :, 0] = 100
aa[-border:, :, 0] = 100
aa[:, :border, 0] = 100
aa[:, -border:, 0] = 100

# Scratch: run a command, feed it stdin and collect its output and exit code.
# NOTE(review): `cmd` is not defined at module level in this file; bind it to
# an argument list before this runs.
p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
output, err = p.communicate(b"input data that is passed to subprocess' stdin")
rc = p.returncode

#from skimage.feature.util import _mask_border_keypoints, DescriptorExtractor
class zernike(DescriptorExtractor):
    """Binary intensity-comparison descriptor extractor.

    NOTE(review): despite the name this does not compute Zernike moments; the
    sampling and border masking follow scikit-image's BRIEF extractor.

    Parameters
    ----------
    descriptor_size : int
        Number of binary tests (bits) per descriptor.
    patch_size : int
        Test locations are sampled inside a patch of this size around each
        keypoint.
    sigma : float
        Stored but not used by ``extract``.
    sample_seed : int
        Seed of the random test-location pattern. Descriptors are only
        comparable between extractors that share the same seed.
    mask : ndarray of bool, optional
        Replaced by ``extract`` with the mask of keypoints kept.
    """

    def __init__(self, descriptor_size=256, patch_size=49,
                 sigma=1, sample_seed=1, mask=None):
        self.descriptor_size = descriptor_size
        self.patch_size = patch_size
        self.sigma = sigma
        self.sample_seed = sample_seed

        self.descriptors = None
        self.mask = mask

    def extract(self, image, keypoints):
        """Compute descriptors for ``keypoints`` (rows of (row, col)) in ``image``.

        Sets ``self.mask`` to the keypoints far enough from the border and
        ``self.descriptors`` to a (kept keypoints, descriptor_size) bool array.
        """
        patch_size = self.patch_size
        desc_size = self.descriptor_size
        # seeded: an unseeded RandomState gives every extractor a different
        # test pattern, making descriptors from two extractors incomparable
        random = np.random.RandomState(self.sample_seed)
        samples = (patch_size / 5.0) * random.randn(desc_size * 8)
        samples = np.array(samples, dtype=np.int32)
        samples = samples[(samples < (patch_size // 2))
                          & (samples > - (patch_size - 2) // 2)]

        pos1 = samples[:desc_size * 2].reshape(desc_size, 2)
        pos2 = samples[desc_size * 2:desc_size * 4].reshape(desc_size, 2)
        pos1 = np.ascontiguousarray(pos1)
        pos2 = np.ascontiguousarray(pos2)
        self.mask = _mask_border_keypoints(image.shape, keypoints,
                                           patch_size // 2)
        # asarray copies only when needed (np.array(copy=False) raises on NumPy 2)
        keypoints = np.asarray(keypoints[self.mask, :], dtype=np.intp, order='C')
        self.descriptors = np.zeros((keypoints.shape[0], desc_size),
                                    dtype=bool, order='C')
        _zern_loop(image, self.descriptors.view(np.uint8), keypoints,
                   pos1, pos2)
def _zern_loop(image, descriptors, keypoints, pos0, pos1):
    for p in range(pos0.shape[0]):
        pr0 = pos0[p, 0]
        pc0 = pos0[p, 1]
        pr1 = pos1[p, 0]
        pc1 = pos1[p, 1]
        for k in range(keypoints.shape[0]):
            kr = keypoints[k, 0]
            kc = keypoints[k, 1]
            if image[kr + pr0, kc + pc0] < image[kr + pr1, kc + pc1]:
                descriptors[k, p] = True
                from mahotas.features import zernike_moments

# Demo: detect Harris corners in two images and describe them with the same
# (seeded) sampling pattern so the descriptors can be matched.
# NOTE(review): img1 and img2 are not defined in this file; they must be
# bound to grayscale images before this runs.
br1 = zernike()
#keypoints1 = corner_peaks(corner_harris(img1), min_distance=5)
keypoints1 = corner_peaks(corner_harris(img1, method='eps', eps=.001, sigma=3), min_distance=5)
br1.extract(img1, keypoints1)
descriptors1 = br1.descriptors
keypoints1 = keypoints1[br1.mask]

br2 = zernike()
#keypoints2 = corner_peaks(corner_harris(img2), min_distance=5)
keypoints2 = corner_peaks(corner_harris(img2, method='eps', eps=.001, sigma=3), min_distance=5)
# extract from img2's own keypoints (was the undefined name `keypoints`)
br2.extract(img2, keypoints2)
descriptors2 = br2.descriptors
keypoints2 = keypoints2[br2.mask]

