# Matching of curves: optimal transport vs. kernel-varifold data attachment

# Preliminary imports to get the right path to lddmm_python...
import os
import sys

# lddmm_python is expected to live in the parent of the *current working
# directory* (not of this file): make it importable without installing it.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

import lddmm_python  # My library folder

import time
from pylab import pause, inf
import numpy as np

from scipy.optimize import minimize

import theano
from theano import tensor as T

# Import of the relevant manifold
from lddmm_python.modules.manifolds.theano_curves import TheanoCurves
from lddmm_python.modules.manifolds.curves import Curve
from lddmm_python.modules.data_attachment.sinkhorn  import SinkhornOptions
from lddmm_python.modules.data_attachment.varifolds import VarifoldOptions
# NOTE(review): the module path of level_curves was lost in the source
# ("from import level_curves"); reconstructed here - confirm against the package.
from lddmm_python.modules.io.level_lines import level_curves

# To illustrate the efficiency of the OT data attachment term,
# we now solve a simple matching problem between two curves in the plane.

npoints = 200  # number of points requested from level_curves for each curve

# Source and target curves, extracted from two images.
q0 = level_curves('data/source.png', npoints)
Xt = level_curves('data/target.png', npoints)

# Joint bounding box of both point sets, so that source and target
# go through exactly the same normalization.
all_pts = np.vstack([q0.to_array(), Xt.to_array()])
mini, maxi = all_pts.min(axis=0), all_pts.max(axis=0)
midpoint = 0.5 * (mini + maxi)
axis_len = maxi - mini

# Normalization parameters: center of the box, and its largest side length.
m0, s0 = midpoint, axis_len.max()

for curve in (q0, Xt):
    curve.translate_rescale(m0, s0)
print('Data have been rescaled to fit in the unit square.')

# Plain numpy copy of the (rescaled) source; the target is kept in varifold form.
Q0 = q0.to_array()
nq, d = len(Q0), Q0.shape[1]

Xt_emb = Xt.to_varifold()

def compute_matching(foldername,
                     details_scale, max_interaction_scale,
                     attachment_type,
                     orientation_weight, orientation_order,
                     maxit_sinkhorn = 10000, maxit_descent = 1000) :
    """Match the source curve onto the target curve and save the resulting plots.

    The model is a ``TheanoCurves`` orbit of the module-level curve ``q0``.
    The initial momentum of the shooting is optimized with L-BFGS-B, with weights
    (1e-3 for the geodesic squared length, 1 for the data attachment).
    Relies on the module-level globals ``q0``, ``Q0``, ``nq``, ``d``, ``Xt`` and ``Xt_emb``.

    Parameters
    ----------
    foldername : str
        Sub-folder of ``'results/vtk_files/'`` where every plot is written.
    details_scale : float
        'varifold-kernel': width of the gaussian kernel.
        'varifold-sinkhorn': square root of the Sinkhorn ``epsilon``.
    max_interaction_scale : float
        'varifold-sinkhorn' only: square root of the Sinkhorn ``rho``.
    attachment_type : str
        Either ``'varifold-kernel'`` or ``'varifold-sinkhorn'``.
    orientation_weight, orientation_order :
        Forwarded to ``VarifoldOptions`` ('varifold-sinkhorn' only).
    maxit_sinkhorn : int
        Maximum number of Sinkhorn iterations ('varifold-sinkhorn' only).
    maxit_descent : int
        Maximum number of L-BFGS-B iterations.

    Raises
    ------
    ValueError
        If ``attachment_type`` is not one of the two supported values.
    """
    # Validate first, so that a typo fails fast instead of leaving
    # use_transport unbound further down.
    if attachment_type == 'varifold-kernel' :
        use_transport = False
    elif attachment_type == 'varifold-sinkhorn' :
        use_transport = True
    else :
        raise ValueError("Unknown attachment_type: {!r} (expected 'varifold-kernel' "
                         "or 'varifold-sinkhorn').".format(attachment_type))
    # ========================================================================================
    if not use_transport :
        data_attachment = (attachment_type, ('gaussian', details_scale))
    else :
        # In the paper, we use a simple "transport-only/Wasserstein" cost, with autodiff,
        # which is the "bug-proofed" version.
        # NOTE(review): this call was reconstructed from a garbled source; confirm that
        # SinkhornOptions takes the VarifoldOptions features as its first argument.
        data_attachment = (attachment_type,
                           SinkhornOptions(
                               VarifoldOptions(
                                   orientation_weight = orientation_weight,
                                   orientation_order  = orientation_order ),
                               epsilon         = details_scale**2,
                               niter           = maxit_sinkhorn,  # Won't be reached in practice
                               rho             = max_interaction_scale**2,
                               tau             = -.8,    # Good enough acceleration
                               dual_cost       = False,  # (Dual v) Primal
                               discard_entropy = True,   # Remove  (v Keep) the -eps*H(g) in the primal
                               discard_KL      = True,   # Discard (v Compute) the rho * KL(...) term ?
                               grad_hack       = False,  # (Maths v) AutoDiff
                               display_error   = False   # True to show the number of Sinkhorn steps (involves a slight hack)
                           ))
    M = TheanoCurves(q0,      # TheanoCurves models the orbit of the curve q0
                 kernel = ('gaussian', [(1., .025), (.75, .15)]), # Good enough kernel : high frequencies + large carriage
                 weights               = (0.001, 1), # Weights : 1. for the attachment, 1e-3 for the geodesic squared length
                 data_attachment       = data_attachment,
                 plot_interactive      = False,
                 plot_file             = True,
                 foldername            = 'results/vtk_files/' + foldername
                )
    # =========================================================================================
    # Local density estimation - useful for the LBFGS preconditionning :
    # the shooting is parametrized by a normalized moment r0

    vertex = T.matrix()
    M_kernel = theano.function([vertex], M._Kq(vertex), allow_input_downcast=True)
    K_Q0 = M_kernel(Q0)
    dens = np.sum(K_Q0, 1)  # row sums of the kernel matrix at the source points

    def p0_from_r0(r0) :
        "Map the flat optimization variable r0 to a momentum p0, divided pointwise by dens."
        p0 = r0.reshape((nq,d))
        p0 =  (p0.T * (1./ dens)).T
        return p0
    def dr0_from_dp0(dp0) :
        "Adjoint of a pointwise multiplication and transposes : self."
        dr0 = (dp0.T * (1./ dens)).T
        dr0 = dr0.ravel()
        return dr0
    # ======================================================================================================
    # L-BFGS minimization
    P0 = np.zeros((nq,d)) # Null initialization for the shooting momentum
    # N.B. : in actual fact, we plot every single model/plan along the line search,
    #        not only those that actually correspond to a BFGS descent.
    #        This is more accurate to estimate the cost of the algorithm.
    def matching_problem(r0) :
        "Cost and gradient (w.r.t. r0) of the shooting from Q0; also plots the current state."
        p0 = p0_from_r0(r0)
        matching_problem.it += 1  # index of this cost evaluation, used in the plot names
        [c, dq_c, dp_c, q1, cost_info] = M.shooting_cost(Q0, p0, target = Xt_emb)
        print('Cost value : ', c)
        plan = cost_info
        it = str(matching_problem.it)
        M.quiver(Q0, p0,                 name='Descent/Momentums/Momentum_' + it)
        M.marker(q1,                     name='Descent/Models/Model_' + it)
        if use_transport :
            M.show_transport(q1, Xt, plan, name='Descent/Plans/Plan_' + it)
        # Chain rule through p0 = r0 / dens (dr0_from_dp0 is the adjoint of p0_from_r0).
        # The fortran routines used by scipy.optimize expect float64 vectors
        # instead of the gpu-friendly float32 matrices :
        dr0 = dr0_from_dp0(dp_c)
        return (c, dr0.astype('float64'))
    matching_problem.it = 0

    time1 = time.time()
    res = minimize( matching_problem,     # function to minimize
                    P0.ravel(),           # starting estimate
                    method = 'L-BFGS-B',  # an order 2 method
                    jac = True,           # matching_problems also returns the gradient
                    options = dict(
                        disp    = True,
                        maxiter = maxit_descent, # Won't be reached in practice
                        ftol    = .0000001, # Don't bother fitting the shapes to float precision, even for the paper...
                        maxcor  = 10      # Number of previous gradients used to approximate the Hessian
                    ))
    time2 = time.time()

    P0 = p0_from_r0(res.x)
    # Depending on the SciPy version, the message may be bytes or str.
    message = res.message
    if isinstance(message, bytes) :
        message = message.decode('UTF-8')
    print('Convergence success  : ', res.success, ', status = ', res.status)
    print('Optimization message : ', message)
    print('Final cost   after ', res.nit, ' iterations : ', res.fun)
    print('Elapsed time after ', res.nit, ' iterations : ', '{0:.2f}'.format(time2 - time1), 's')
    # =================================================================================================
    # Visualize the end point
    [Qt, Pt] = M.hamiltonian_trajectory(Q0, P0)

    M.current_axis = []
    M.marker(Q0,     name='Template')
    M.plot_traj(Qt,  name='Shoot/Shoot')
    M.plot_momentums(Qt, Pt, name='Momentums/Momentum')
    M.marker(Qt[-1], name='Model')

    Gt = M.grid_trajectory(Q0, P0, [(-.5,.5), (-.5,.5)], nlines = 21)
    M.file_plot_grids(Gt, 'Grid/grid')

# Time to compute all that!

# Each entry: (output sub-folder, details_scale, max_interaction_scale).
# Every run uses orientation_weight = 1. and orientation_order = 4.
kernel_runs = [
    ('kernel_big/',   .2,  .2),
    ('kernel_small/', .05, .05),
]
# Fixed details_scale (.03), increasing max_interaction_scale.
sinkhorn_rho_runs = [
    ('sinkhorn_eps-m_rho-s/', .03,  .1),
    ('sinkhorn_eps-m_rho-m/', .03, .15),
    ('sinkhorn_eps-m_rho-l/', .03,  .5),
]
# Fixed max_interaction_scale (.5), varying details_scale. The medium case
# ('sinkhorn_eps-m_rho-l/', .03, .5) is already computed above, hence omitted.
sinkhorn_eps_runs = [
    ('sinkhorn_eps-l_rho-l/', .1,   .5),
    ('sinkhorn_eps-s_rho-l/', .015, .5),
]

for folder, scale, reach in kernel_runs:
    compute_matching(folder, scale, reach, 'varifold-kernel', 1., 4)

for folder, scale, reach in sinkhorn_rho_runs + sinkhorn_eps_runs:
    compute_matching(folder, scale, reach, 'varifold-sinkhorn', 1., 4)

# Disabled experiment: just to see how the plan evolves with the number of
# Sinkhorn iterations (two descent steps per run). Flip to True to run it.
run_sinkhorn_iteration_study = False
if run_sinkhorn_iteration_study:
    for folder, n_sinkhorn in (('sinkhorn_it5/', 5), ('sinkhorn_it10/', 10), ('sinkhorn_it25/', 25)):
        compute_matching(folder, .05, .5, 'varifold-sinkhorn', 1., 4,
                         maxit_sinkhorn = n_sinkhorn, maxit_descent = 2)

