Over-sampling using LMNN


In [2]:
%matplotlib inline

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import seaborn as sns
sns.set()

import numpy as np
import itertools

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

# Save a nice dark grey as a variable
almost_black = '#262626'

%run '/home/glemaitre/anaconda/lib/python2.7/site-packages/modshogun.py'
import modshogun

Generate an imbalanced two-class data set. It only has two features, so no dimensionality reduction is needed to visualise it.


In [3]:
# Draw an imbalanced two-class problem (class weights 0.1 / 0.9) with two features
x, y = make_classification(n_samples=10000, n_features=2, n_informative=1,
                           n_redundant=0, n_clusters_per_class=1,
                           class_sep=1., weights=[0.1, 0.9], random_state=9)

# The data are already two-dimensional, so they are plotted as they are:
# one scatter layer per class
palette = sns.color_palette()
for class_name, class_label, class_colour in (("Class #0", 0, palette[0]),
                                              ("Class #1", 1, palette[2])):
    in_class = (y == class_label)
    plt.scatter(x[in_class, 0], x[in_class, 1], label=class_name, alpha=0.5,
                edgecolor=almost_black, facecolor=class_colour, linewidth=0.15)

plt.legend()
plt.show()


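Learn an LMNN metric on the neighbourhood of each minority-class sample (only when that neighbourhood holds more minority than majority samples, but fewer than twice as many)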


In [10]:
from collections import Counter

from sklearn.neighbors import NearestNeighbors

# Count the samples per class, then pick the rarest and the most frequent label
stat_class = Counter(y)
label_min_class = min(stat_class, key=stat_class.get)
label_maj_class = max(stat_class, key=stat_class.get)
# Positions (rows of x / entries of y) of the samples of these two classes
idx_min_class = np.flatnonzero(y == label_min_class)
idx_maj_class = np.flatnonzero(y == label_maj_class)

neighbours_required = 7

# Nearest-neighbour index built on the minority class only
min_class_NN = NearestNeighbors(n_neighbors=neighbours_required, metric='l2')
min_class_NN.fit(x[idx_min_class, :], y[idx_min_class])

# Nearest-neighbour index built on the majority class only
max_class_NN = NearestNeighbors(n_neighbors=neighbours_required, metric='l2')
max_class_NN.fit(x[idx_maj_class, :], y[idx_maj_class])

# Nearest-neighbour index built on the complete data set
data_NN = NearestNeighbors(n_neighbors=1, metric='l2')
data_NN.fit(x, y)

# For every minority sample, query its k neighbours inside each class ...
dist_min, ind_min = min_class_NN.kneighbors(x[idx_min_class, :])
dist_max, ind_max = max_class_NN.kneighbors(x[idx_min_class, :])
# ... and keep only the distance to the farthest of them
dist_min = dist_min.max(axis=1)
dist_max = dist_max.max(axis=1)

# Search radius of each minority sample: the distance to its farthest
# minority-class neighbour (dist_max is computed above but not used here)
max_dist = dist_min.copy()

# Collect, for each minority sample, the indices of all the samples lying
# within its own radius; a loop is needed because the radius differs per sample
s_considers = []
for sample, radius, sample_idx in zip(x[idx_min_class], max_dist, idx_min_class):
    _, found = data_NN.radius_neighbors(X=np.atleast_2d(sample), radius=radius)
    neighbour_idx = found[0]
    # Swap the sample's own index into the first position
    own_position = np.nonzero(neighbour_idx == sample_idx)
    neighbour_idx[own_position] = neighbour_idx[0]
    neighbour_idx[0] = sample_idx
    s_considers.append(neighbour_idx)
    
from metric_learn import LMNN, ITML, SDML, LSML

# NOTE(review): `make_covariance_ellipse` is defined in a later cell of this
# notebook; that cell has to be executed before this loop can run.

# One entry per minority sample: the metric learnt by LMNN on its neighbourhood
# or, when LMNN is not fitted, the covariance of that neighbourhood
cov_mat = []
# One flag per minority sample: True when the cov_mat entry comes from LMNN
idx_cons = []
for s in s_considers:
    # Extract the samples and labels of the neighbourhood
    x_s = x[np.ravel(s), :]
    y_s = y[np.ravel(s)]

    # Class composition of the neighbourhood (single-argument print(...)
    # behaves identically on Python 2 and Python 3)
    stat = Counter(y_s)
    print(stat)

    # LMNN is only fitted when there are impostors and the minority class
    # outnumbers them by a factor strictly between 1 and 2.  A Counter returns
    # 0 for an absent label, so `n_maj > 0` also protects the division.
    n_min = stat[label_min_class]
    n_maj = stat[label_maj_class]
    fit_lmnn = n_maj > 0 and 1. < float(n_min) / float(n_maj) < 2.

    if not fit_lmnn:
        # Fall back to the plain covariance of the neighbourhood
        cov_mat.append(np.ma.cov(x_s.T))
        idx_cons.append(False)
        continue

    # Fit the LMNN for these data
    s_lmnn = LMNN(k=1, convergence_tol=1e-9, min_iter=5000, max_iter=5000)
    s_lmnn.fit(x_s, y_s)

    # Learnt metric and its inverse, computed once and reused below
    # NOTE(review): the inverse becomes huge when the metric is close to
    # singular (see the 1e+14 entries in the recorded output).
    cm = np.matrix(s_lmnn.metric())
    cm_inv = cm.I
    print(cm_inv)

    # Show the neighbourhood together with the ellipse of the inverse metric,
    # centred on the sample of interest (first row of x_s)
    figure, axis = plt.subplots(1, 1)
    plt.scatter(x_s[y_s == 0, 0], x_s[y_s == 0, 1], label="Class #0", alpha=0.5,
                edgecolor=almost_black, facecolor=palette[0], linewidth=0.15)
    plt.scatter(x_s[y_s == 1, 0], x_s[y_s == 1, 1], label="Class #1", alpha=0.5,
                edgecolor=almost_black, facecolor=palette[2], linewidth=0.15)
    elli = make_covariance_ellipse(cm_inv, x_s[0], 1)
    axis.add_artist(elli)
    plt.show()

    # Store the learnt metric
    cov_mat.append(cm)
    idx_cons.append(True)


Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 828, 0: 7})
Counter({1: 14, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 6})
[[ 0.00579373 -0.00040028]
 [-0.00040028  0.00782237]]
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 240, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 5})
[[   0.90852223   26.68875268]
 [  26.68875268  794.294801  ]]
Counter({1: 841, 0: 6})
Counter({0: 7, 1: 1})
Counter({0: 6, 1: 1})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 4})
[[ 0.15186859 -0.05587179]
 [-0.05587179  0.36288714]]
Counter({1: 564, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 4})
[[ 0.00463059 -0.00040301]
 [-0.00040301  0.00874444]]
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 3})
Counter({1: 46, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 4})
[[ 0.0498796  -0.01741362]
 [-0.01741362  0.15355456]]
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 15, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 29, 0: 6})
Counter({0: 7, 1: 6})
[[ 0.01695323 -0.006352  ]
 [-0.006352    0.01267806]]
Counter({0: 6})
Counter({1: 106, 0: 6})
Counter({1: 114, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({1: 151, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({1: 9, 0: 6})
Counter({0: 7, 1: 4})
[[ 0.00649212  0.00399015]
 [ 0.00399015  0.00975788]]
Counter({1: 10, 0: 7})
Counter({1: 907, 0: 7})
Counter({0: 7, 1: 4})
[[ 0.02388974 -0.00350034]
 [-0.00350034  0.03611192]]
Counter({1: 214, 0: 7})
Counter({1: 81, 0: 7})
Counter({0: 7})
Counter({1: 224, 0: 7})
Counter({1: 213, 0: 7})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({1: 113, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 58, 0: 6})
Counter({0: 6, 1: 1})
Counter({0: 6, 1: 1})
Counter({1: 11, 0: 7})
Counter({0: 7})
Counter({1: 316, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 42, 0: 7})
Counter({0: 7})
Counter({1: 67, 0: 6})
Counter({0: 7, 1: 1})
Counter({1: 9, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 6})
Counter({0: 7})
Counter({1: 69, 0: 7})
Counter({0: 6})
Counter({0: 6})
Counter({0: 6})
Counter({0: 7, 1: 3})
Counter({1: 152, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({1: 198, 0: 7})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 19, 0: 7})
Counter({0: 7})
Counter({1: 39, 0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7})
Counter({1: 52, 0: 7})
Counter({0: 6})
Counter({1: 75, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({1: 166, 0: 7})
Counter({1: 232, 0: 7})
Counter({0: 6, 1: 2})
Counter({0: 6, 1: 3})
Counter({0: 7, 1: 1})
Counter({0: 6, 1: 5})
[[ 0.52303096 -0.41050632]
 [-0.41050632  0.33334457]]
Counter({1: 24, 0: 7})
Counter({0: 7})
Counter({1: 75, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({1: 1059, 0: 7})
Counter({1: 84, 0: 7})
Counter({1: 10, 0: 7})
Counter({0: 7})
Counter({1: 577, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({1: 29, 0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7, 1: 6})
[[ 0.01868293 -0.00139871]
 [-0.00139871  0.01798727]]
Counter({0: 7})
Counter({0: 6, 1: 6})
Counter({0: 6, 1: 6})
Counter({1: 341, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 29, 0: 6})
Counter({0: 7, 1: 3})
Counter({1: 101, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 10, 0: 7})
Counter({1: 179, 0: 7})
Counter({1: 87, 0: 7})
Counter({1: 9, 0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7, 1: 4})
[[  5.52277569e-03  -1.38829104e-05]
 [ -1.38829104e-05   9.32265738e-03]]
Counter({1: 520, 0: 6})
Counter({1: 8, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 14, 0: 7})
Counter({0: 6, 1: 2})
Counter({1: 71, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 4})
[[ 0.00577334  0.00376219]
 [ 0.00376219  0.01585844]]
Counter({1: 104, 0: 7})
Counter({0: 7})
Counter({1: 16, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({1: 515, 0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 6})
Counter({1: 215, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 39, 0: 6})
Counter({1: 19, 0: 7})
Counter({0: 7, 1: 5})
[[ 0.00237759  0.00329611]
 [ 0.00329611  0.01452417]]
Counter({1: 433, 0: 7})
Counter({0: 7})
Counter({0: 6, 1: 2})
Counter({1: 93, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 38, 0: 7})
Counter({0: 6})
Counter({1: 29, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({1: 43, 0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({1: 38, 0: 7})
Counter({0: 7})
Counter({1: 13, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 563, 0: 6})
Counter({1: 8, 0: 6})
Counter({1: 10, 0: 7})
Counter({1: 64, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 6, 1: 3})
Counter({0: 7, 1: 5})
[[ 0.01208267  0.00335002]
 [ 0.00335002  0.0077979 ]]
Counter({0: 7})
Counter({0: 7})
Counter({1: 82, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 6, 1: 2})
Counter({0: 7, 1: 4})
[[ 0.00439009 -0.00223409]
 [-0.00223409  0.00553405]]
Counter({1: 37, 0: 7})
Counter({0: 7})
Counter({1: 526, 0: 7})
Counter({0: 7})
Counter({0: 6, 1: 6})
Counter({0: 6})
Counter({0: 7})
Counter({1: 216, 0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7, 1: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 2})
Counter({0: 6})
Counter({0: 7})
Counter({1: 91, 0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 745, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({1: 48, 0: 7})
Counter({1: 67, 0: 6})
Counter({0: 6})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({1: 25, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 7, 0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({0: 6, 1: 3})
Counter({1: 46, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 240, 0: 6})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({1: 154, 0: 6})
Counter({0: 7})
Counter({1: 8, 0: 6})
Counter({0: 6, 1: 2})
Counter({1: 184, 0: 6})
Counter({0: 6, 1: 5})
[[ 0.00466381 -0.00074879]
 [-0.00074879  0.00215031]]
Counter({1: 25, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 7, 1: 3})
Counter({1: 177, 0: 7})
Counter({1: 463, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({1: 1142, 0: 7})
Counter({0: 7})
Counter({1: 16, 0: 6})
Counter({0: 6, 1: 1})
Counter({0: 7, 1: 1})
Counter({1: 65, 0: 7})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({1: 752, 0: 7})
Counter({1: 13, 0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 3})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 7})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 5})
[[ 0.00711407  0.00330653]
 [ 0.00330653  0.01098226]]
Counter({1: 46, 0: 7})
Counter({0: 6})
Counter({1: 10, 0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 7})
Counter({1: 619, 0: 6})
Counter({0: 7})
Counter({0: 6, 1: 2})
Counter({1: 8, 0: 6})
Counter({0: 7, 1: 4})
[[ 0.00792307  0.00213326]
 [ 0.00213326  0.00703884]]
Counter({0: 7})
Counter({1: 116, 0: 7})
Counter({1: 122, 0: 7})
Counter({1: 48, 0: 6})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 6, 1: 1})
Counter({0: 7, 1: 6})
[[  6.35506132e-03   9.36611864e-05]
 [  9.36611864e-05   8.22892678e-03]]
Counter({0: 6})
Counter({0: 7})
Counter({1: 28, 0: 7})
Counter({1: 15, 0: 7})
Counter({1: 913, 0: 7})
Counter({1: 11, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 3})
Counter({0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 325, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 19, 0: 7})
Counter({0: 7, 1: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({1: 1370, 0: 7})
Counter({0: 7, 1: 4})
[[ 0.03840033  0.10856947]
 [ 0.10856947  0.38816604]]
Counter({0: 6})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({1: 39, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 20, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({1: 17, 0: 7})
Counter({1: 220, 0: 7})
Counter({0: 7})
Counter({1: 83, 0: 6})
Counter({1: 23, 0: 7})
Counter({1: 20, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 4})
[[ 176405.64113981 -397526.80458942]
 [-397526.80458942  895819.23526599]]
Counter({0: 7})
Counter({0: 6})
Counter({0: 6})
Counter({1: 70, 0: 7})
Counter({1: 10, 0: 7})
Counter({0: 7})
Counter({1: 219, 0: 7})
Counter({0: 6})
Counter({1: 55, 0: 7})
Counter({1: 32, 0: 7})
Counter({1: 16, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 17, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({1: 55, 0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({1: 45, 0: 6})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 3})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({1: 115, 0: 7})
Counter({0: 7, 1: 1})
Counter({1: 19, 0: 6})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 37, 0: 7})
Counter({0: 7})
Counter({0: 6, 1: 4})
[[  1.18221384e+14   5.95775002e+14]
 [  5.95775002e+14   3.00239975e+15]]
Counter({0: 6})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 6})
Counter({1: 76, 0: 7})
Counter({1: 17, 0: 7})
Counter({0: 7})
Counter({1: 37, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({1: 7, 0: 6})
Counter({1: 15, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({1: 10, 0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7, 1: 7})
Counter({0: 7, 1: 6})
[[ 0.00395137 -0.0016684 ]
 [-0.0016684   0.00681988]]
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 4})
[[ 0.38564981 -0.00436144]
 [-0.00436144  0.17324145]]
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 369, 0: 7})
Counter({1: 13, 0: 7})
Counter({1: 7, 0: 6})
Counter({0: 7})
Counter({1: 16, 0: 7})
Counter({1: 629, 0: 7})
Counter({0: 6})
Counter({0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7, 1: 3})
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({1: 31, 0: 7})
Counter({1: 306, 0: 7})
Counter({1: 9, 0: 7})
Counter({1: 70, 0: 6})
Counter({1: 1439, 0: 6})
Counter({0: 7})
Counter({0: 6, 1: 3})
Counter({0: 6})
Counter({1: 11, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 5})
[[ 0.00496856 -0.0042735 ]
 [-0.0042735   0.01417977]]
Counter({0: 7})
Counter({0: 7})
Counter({1: 19, 0: 7})
Counter({1: 26, 0: 7})
Counter({1: 19, 0: 7})
Counter({1: 12, 0: 6})
Counter({0: 7, 1: 2})
Counter({0: 7, 1: 2})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 7})
Counter({0: 7, 1: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({1: 10, 0: 7})
Counter({1: 373, 0: 6})
Counter({0: 7})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({1: 21, 0: 7})
Counter({1: 114, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 7})
Counter({1: 217, 0: 7})
Counter({1: 21, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 5})
[[ 0.0782789   0.00127634]
 [ 0.00127634  0.0372014 ]]
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 14, 0: 7})
Counter({1: 16, 0: 7})
Counter({1: 8, 0: 7})
Counter({0: 7, 1: 2})
Counter({0: 6, 1: 1})
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({0: 7, 1: 3})
Counter({1: 1023, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 6})
[[ 0.03182352  0.02091355]
 [ 0.02091355  0.06887874]]
Counter({0: 6, 1: 6})
Counter({0: 6})
Counter({0: 6, 1: 4})
[[ 0.01525936 -0.00208017]
 [-0.00208017  0.00144463]]
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({1: 14, 0: 7})
Counter({1: 45, 0: 7})
Counter({1: 95, 0: 7})
Counter({1: 345, 0: 7})
Counter({1: 68, 0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 5})
[[ 0.01806509 -0.00783557]
 [-0.00783557  0.01401319]]
Counter({0: 7})
Counter({0: 6})
Counter({0: 6, 1: 2})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 4})
[[ 0.0617996   0.0118647 ]
 [ 0.0118647   0.01548934]]
Counter({1: 1357, 0: 6})
Counter({1: 56, 0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7, 1: 1})
Counter({1: 230, 0: 7})
Counter({0: 7})
Counter({1: 1424, 0: 7})
Counter({0: 7, 1: 1})
Counter({1: 82, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({1: 398, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 26, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({1: 848, 0: 7})
Counter({0: 6})
Counter({1: 14, 0: 7})
Counter({1: 42, 0: 7})
Counter({0: 7, 1: 3})
Counter({1: 346, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 45, 0: 7})
Counter({1: 59, 0: 6})
Counter({1: 17, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({1: 35, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 6})
Counter({0: 7, 1: 3})
Counter({1: 35, 0: 7})
Counter({0: 6})
Counter({0: 6})
Counter({0: 7, 1: 7})
Counter({0: 7, 1: 2})
Counter({0: 6})
Counter({1: 74, 0: 6})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 6})
Counter({1: 38, 0: 7})
Counter({0: 7, 1: 5})
[[ 0.00790449  0.00024805]
 [ 0.00024805  0.00211973]]
Counter({0: 7, 1: 1})
Counter({1: 38, 0: 6})
Counter({1: 334, 0: 6})
Counter({1: 484, 0: 7})
Counter({1: 25, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 3})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 3})
Counter({1: 12, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 2})
Counter({1: 11, 0: 7})
Counter({1: 25, 0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 105, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 6})
[[ 0.00619698  0.00100345]
 [ 0.00100345  0.00966861]]
Counter({1: 132, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({1: 73, 0: 7})
Counter({0: 7})
Counter({1: 8, 0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 2})
Counter({1: 27, 0: 7})
Counter({0: 7})
Counter({1: 84, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 10, 0: 7})
Counter({0: 7, 1: 1})
Counter({1: 44, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 11, 0: 7})
Counter({1: 345, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 5})
[[ 0.00831218 -0.00292618]
 [-0.00292618  0.0071887 ]]
Counter({1: 9, 0: 6})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({1: 40, 0: 7})
Counter({0: 7, 1: 1})
Counter({1: 10, 0: 6})
Counter({1: 178, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 50, 0: 7})
Counter({0: 7, 1: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({1: 11, 0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({1: 62, 0: 6})
Counter({1: 1206, 0: 7})
Counter({1: 10, 0: 7})
Counter({1: 37, 0: 7})
Counter({0: 7})
Counter({1: 8, 0: 7})
Counter({1: 12, 0: 6})
Counter({1: 72, 0: 6})
Counter({1: 50, 0: 7})
Counter({0: 7})
Counter({1: 295, 0: 7})
Counter({0: 6})
Counter({1: 87, 0: 7})
Counter({1: 1067, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 84, 0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({1: 87, 0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 2})
Counter({1: 412, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({1: 472, 0: 7})
Counter({1: 49, 0: 7})
Counter({0: 7, 1: 2})
Counter({1: 19, 0: 7})
Counter({1: 12, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 14, 0: 6})
Counter({1: 40, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({1: 497, 0: 6})
Counter({0: 6})
Counter({1: 73, 0: 6})
Counter({1: 42, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 316, 0: 6})
Counter({1: 18, 0: 6})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({1: 93, 0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({1: 203, 0: 7})
Counter({0: 7, 1: 1})
Counter({1: 180, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 382, 0: 6})
Counter({1: 8, 0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({1: 195, 0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({1: 192, 0: 7})
Counter({1: 130, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 87, 0: 6})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({0: 6, 1: 4})
[[ 0.01285698  0.00341472]
 [ 0.00341472  0.02219753]]
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({1: 141, 0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 6})
Counter({0: 6})
Counter({0: 7})
Counter({0: 6, 1: 1})
Counter({1: 13, 0: 7})
Counter({1: 325, 0: 7})
Counter({0: 7})
Counter({1: 349, 0: 6})
Counter({1: 9, 0: 7})
Counter({0: 7})
Counter({0: 6, 1: 3})
Counter({1: 131, 0: 7})
Counter({1: 40, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({0: 6, 1: 1})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 6})
Counter({1: 25, 0: 7})
Counter({0: 7})
Counter({1: 180, 0: 7})
Counter({0: 7})
Counter({1: 17, 0: 6})
Counter({1: 486, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7, 1: 4})
[[ 0.07399623  0.12570212]
 [ 0.12570212  0.24108038]]
Counter({1: 331, 0: 7})
Counter({0: 6})
Counter({1: 84, 0: 7})
Counter({0: 7})
Counter({1: 10, 0: 7})
Counter({0: 6})
Counter({1: 39, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({1: 8, 0: 6})
Counter({0: 7, 1: 5})
[[ 0.00388791  0.00228342]
 [ 0.00228342  0.01579654]]
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 5})
[[ 0.03282322  0.00851875]
 [ 0.00851875  0.04404702]]
Counter({0: 7})
Counter({0: 6})
Counter({1: 11, 0: 6})
Counter({1: 124, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 19, 0: 6})
Counter({0: 7})
Counter({1: 1131, 0: 7})
Counter({0: 6})
Counter({1: 22, 0: 6})
Counter({1: 26, 0: 6})
Counter({0: 7, 1: 1})
Counter({1: 75, 0: 6})
Counter({1: 24, 0: 7})
Counter({0: 7, 1: 6})
[[ 0.01775289 -0.0051474 ]
 [-0.0051474   0.01954029]]
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 46, 0: 7})
Counter({0: 7})
Counter({1: 25, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 3})
Counter({0: 6})
Counter({0: 7, 1: 2})
Counter({0: 6})
Counter({0: 6})
Counter({1: 43, 0: 7})
Counter({1: 627, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 17, 0: 6})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 4})
[[  0.28459538  -0.57839373]
 [ -0.57839373  12.65362595]]
Counter({1: 29, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 44, 0: 6})
Counter({0: 7, 1: 1})
Counter({0: 6, 1: 2})
Counter({0: 7, 1: 1})
Counter({1: 1042, 0: 7})
Counter({0: 7})
Counter({0: 6, 1: 4})
[[ 15.53084716  -5.52980813]
 [ -5.52980813   1.97618266]]
Counter({0: 6, 1: 2})
Counter({0: 7})
Counter({0: 6, 1: 1})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 4})
[[ 0.04362212  0.01159575]
 [ 0.01159575  0.02070155]]
Counter({0: 7, 1: 6})
[[ 0.02744641  0.00653263]
 [ 0.00653263  0.03784267]]
Counter({0: 7})
Counter({0: 6})
Counter({1: 611, 0: 6})
Counter({0: 7})
Counter({0: 7})
Counter({1: 14, 0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 17, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 6})
Counter({1: 9, 0: 7})
Counter({0: 7, 1: 5})
[[ 0.02101626 -0.02100619]
 [-0.02100619  0.06383645]]
Counter({0: 6})
Counter({0: 7})
Counter({1: 8, 0: 7})
Counter({1: 9, 0: 6})
Counter({0: 7, 1: 3})
Counter({1: 18, 0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({1: 49, 0: 6})
Counter({0: 6})
Counter({1: 8, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 2})
Counter({0: 7})
Counter({0: 7, 1: 4})
[[ 0.00569901 -0.00227307]
 [-0.00227307  0.00744244]]
Counter({0: 7, 1: 2})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({1: 8, 0: 6})
Counter({1: 25, 0: 7})
Counter({1: 292, 0: 7})
Counter({0: 6, 1: 3})
Counter({0: 7, 1: 3})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7})
Counter({0: 7, 1: 4})
[[ 0.02076846 -0.01106577]
 [-0.01106577  0.04037964]]
Counter({1: 790, 0: 7})
Counter({1: 34, 0: 7})
Counter({0: 7, 1: 1})
Counter({1: 30, 0: 6})
Counter({0: 7, 1: 4})
[[ 0.25246416  0.21697009]
 [ 0.21697009  0.21189227]]
Counter({0: 6})
Counter({0: 7})
Counter({1: 13, 0: 6})
Counter({0: 6, 1: 1})
Counter({0: 6})
Counter({1: 11, 0: 7})
Counter({1: 7, 0: 6})
Counter({0: 7, 1: 5})
[[ 0.00801566  0.00159033]
 [ 0.00159033  0.00852931]]
Counter({0: 7})
Counter({0: 7, 1: 2})
Counter({0: 6})
Counter({0: 7})
Counter({0: 6, 1: 2})
Counter({0: 7, 1: 3})
Counter({0: 6})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({0: 7, 1: 1})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({1: 139, 0: 7})
Counter({0: 7, 1: 7})
Counter({1: 360, 0: 7})
Counter({1: 30, 0: 7})
Counter({0: 7, 1: 2})
Counter({1: 404, 0: 6})
Counter({0: 7, 1: 1})
Counter({0: 7})
Counter({0: 6, 1: 5})
[[ 0.08126242 -0.02923237]
 [-0.02923237  0.02577203]]
Counter({1: 171, 0: 7})
Counter({0: 7})
Counter({0: 7, 1: 1})
Counter({1: 12, 0: 7})
Counter({0: 7})
Counter({0: 6})
Counter({0: 7})
Counter({0: 7, 1: 3})
Counter({0: 6, 1: 1})
Counter({0: 7})
Counter({1: 8, 0: 7})
Counter({0: 7})
Counter({1: 35, 0: 6})
Counter({1: 12, 0: 7})
Counter({0: 6})
Counter({0: 7, 1: 1})
Counter({0: 6, 1: 1})
Counter({1: 7, 0: 6})

In [ ]:
# Print the coordinates of a few samples selected by their row index in x
# (single-argument print(...) works on both Python 2 and Python 3)
xxxx = [172, 6541, 8670, 9961]
print(x[xxxx, :])

Helper that draws the `nstd`-sigma ellipse of a 2x2 covariance matrix on a plot


In [ ]:
def plot_cov_ellipse(cov, pos, nstd=2, ax=None, **kwargs):
    """
    Plots an `nstd` sigma error ellipse based on the specified covariance
    matrix (`cov`). Additional keyword arguments are passed on to the
    ellipse patch artist.

    Parameters
    ----------
        cov : The 2x2 covariance matrix to base the ellipse on
        pos : The location of the center of the ellipse. Expects a 2-element
            sequence of [x0, y0].
        nstd : The radius of the ellipse in numbers of standard deviations.
            Defaults to 2 standard deviations.
        ax : The axis that the ellipse will be plotted on. Defaults to the
            current axis.
        Additional keyword arguments are pass on to the ellipse patch. When
        no `alpha` is given the ellipse is drawn with alpha=0.1; an explicit
        `alpha` is honoured.

    Returns
    -------
        A matplotlib ellipse artist
    """
    if ax is None:
        ax = plt.gca()

    # eigh returns the eigenvalues in ascending order with the eigenvectors
    # as columns; reorder both so that the largest eigenvalue comes first
    vals, vecs = np.linalg.eigh(np.asarray(cov, dtype=float))
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    # Round-off can make a (numerically) zero eigenvalue slightly negative,
    # which would turn the corresponding axis length into NaN
    vals = np.clip(vals, 0., None)

    # Orientation of the major axis, i.e. of the leading eigenvector
    major_axis = vecs[:, 0]
    theta = np.degrees(np.arctan2(major_axis[1], major_axis[0]))

    # Width and height are "full" widths, not radius
    width, height = 2 * nstd * np.sqrt(vals)

    # Default transparency only; do not override a caller-supplied alpha
    kwargs.setdefault('alpha', .1)
    ellip = Ellipse(xy=pos, width=width, height=height, angle=theta, **kwargs)

    ax.add_artist(ellip)
    return ellip

In [11]:
# Scatter both classes of the data set on a fresh figure
figure, axis = plt.subplots(1, 1)
for class_name, class_label, class_colour in (("Class #0", 0, palette[0]),
                                              ("Class #1", 1, palette[2])):
    in_class = (y == class_label)
    plt.scatter(x[in_class, 0], x[in_class, 1], label=class_name, alpha=0.5,
                edgecolor=almost_black, facecolor=class_colour, linewidth=0.15)

# Overlay one ellipse per minority sample whose neighbourhood was fitted with
# LMNN; the stored metric is inverted before being handed to the ellipse helper
for metric, centre, fitted in zip(cov_mat, x[idx_min_class], idx_cons):
    if not fitted:
        continue
    axis.add_artist(make_covariance_ellipse(np.matrix(metric).I, centre, std=2.))

plt.legend()
plt.show()



In [ ]:
# Sum of all the entries of the covariance of the last neighbourhood (x_s);
# a single np.sum already reduces over every axis
print(np.sum(np.cov(x_s.T)))

Easy LMNN example


In [ ]:
# Toy data set: eight 2-D points spread over three classes.
# NOTE(review): this rebinds `x` and `y`, which hold the imbalanced data set
# in the cells above.
x = np.array([
    [0.0, 0.0],
    [-1.0, 0.1],
    [0.3, -0.05],
    [0.7, 0.3],      # class 0 ends here
    [-0.2, -0.6],
    [-0.15, -0.63],  # class 1 ends here
    [-0.25, 0.55],
    [-0.28, 0.67],   # class 2 ends here
])
y = np.repeat([0, 1, 2], [4, 2, 2])

In [ ]:
import matplotlib.pyplot as pyplot

%matplotlib inline

def plot_data(features,labels,axis,alpha=1.0):
    """Draw the samples of classes 0, 1 and 2 on ``axis`` as coloured dots.

    features : 2-D array indexed as ``features[mask]`` and ``[:, 0]`` / ``[:, 1]``.
    labels   : array compared against 0, 1 and 2 to select each class.
    axis     : object providing ``plot`` and the ``set_*`` methods used below.
    alpha    : opacity forwarded to every ``plot`` call.
    """
    # colour assigned to each class label, in drawing order
    class_colors = ((0, 'green'), (1, 'red'), (2, 'blue'))

    # select every class subset first, then draw them one after the other
    subsets = [(features[labels==lab], col) for lab, col in class_colors]
    for points, col in subsets:
        axis.plot(points[:,0], points[:,1], 'o', color=col, markersize=12, alpha=alpha)

    # fixed, square viewport
    axis.set_xlim(-1.5,1.5)
    axis.set_ylim(-1.5,1.5)
    axis.set_aspect('equal')

    # axis captions
    axis.set_xlabel('x')
    axis.set_ylabel('y')

# Show the raw toy samples on a fresh single-panel figure
figure, axis = plt.subplots(nrows=1, ncols=1)
plot_data(x, y, axis)
axis.set_title('Toy data set')
plt.show()

In [5]:
def make_covariance_ellipse(covariance, mean=None, std=2):
    """Build an ellipse patch whose axes follow a 2x2 covariance matrix.

    Parameters
    ----------
    covariance : symmetric 2x2 array-like; its eigen-decomposition gives the
        direction and length of the ellipse axes.
    mean : 2-element centre of the ellipse. Defaults to the origin when omitted.
    std : factor applied to the square root of each eigenvalue to obtain the
        FULL axis length handed to the patch. With the default of 2 each
        semi-axis equals sqrt(eigenvalue).

    Returns
    -------
    A matplotlib.patches.Ellipse (orange, alpha 0.3). It is not added to any
    axes here; the caller does that.
    """
    import matplotlib.patches as patches
    import scipy.linalg       as linalg

    if mean is None:
        # centre on the origin, as callers that pass only a covariance expect
        mean = np.zeros(2)

    # eigh returns the eigenvalues in ascending order and stores the matching
    # eigenvectors as COLUMNS of v
    w,v = linalg.eigh(covariance)
    # the patch width is derived from w[0], so it must be oriented along the
    # eigenvector of w[0], i.e. column 0 (not row 0)
    u = v[:,0]
    # arctan2 also handles a vertical axis (u[0] == 0) without dividing by zero
    angle = np.degrees(np.arctan2(u[1], u[0]))
    # angle is passed by keyword because recent matplotlib releases no longer
    # accept it positionally; adding 180 degrees leaves the ellipse unchanged
    ellipse = patches.Ellipse(mean, std*w[0]**0.5, std*w[1]**0.5, angle=180+angle,
                              color='orange', alpha=0.3)

    return ellipse

# represent the Euclidean distance: identity covariance drawn over the toy data
figure,axis = plt.subplots(1,1)
plot_data(x,y,axis)
# pass the centre explicitly so the ellipse is placed on the origin
ellipse = make_covariance_ellipse(np.eye(2), np.zeros(2))
axis.add_artist(ellipse)
axis.set_title('Euclidean distance')
plt.show()


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-5-a2555c7fdecd> in <module>()
     20 # represent the Euclidean distance
     21 figure,axis = plt.subplots(1,1)
---> 22 plot_data(x,y,axis)
     23 ellipse = make_covariance_ellipse(np.eye(2))
     24 axis.add_artist(ellipse)

NameError: name 'plot_data' is not defined

In [ ]:
# how many target neighbours each example gets
k = 2

# build the learner with its iteration bounds and convergence tolerance,
# then fit it on the toy samples and their labels
lmnn = LMNN(
    k,
    min_iter=50,
    max_iter=1000,
    convergence_tol=1e-9
)
lmnn.fit(x, y)

In [ ]:
# get the linear transform from LMNN
L = lmnn.transformer()
# metric matrix reported by the learner, wrapped as np.matrix so .I inverts it
M = np.matrix(lmnn.metric())
# invert once and reuse for both the printout and the ellipse
M_inv = M.I

# print(...) with a single argument runs on both Python 2 and Python 3
print(M_inv)

# represent the distance given by LMNN
figure,axis = plt.subplots(1,1)
plot_data(x,y,axis)
# pass the centre explicitly so the ellipse is placed on the origin
ellipse = make_covariance_ellipse(M_inv, np.zeros(2))
axis.add_artist(ellipse)
axis.set_title('LMNN distance')
plt.show()

Shogun version


In [14]:
import numpy

# Toy samples, one (x, y) pair per row, grouped by class
x = numpy.array([
    # class 0
    [0, 0], [-1, 0.1], [0.3, -0.05], [0.7, 0.3],
    # class 1
    [-0.2, -0.6], [-0.15, -0.63],
    # class 2
    [-0.25, 0.55], [-0.28, 0.67],
])
# Label of every row of x
y = numpy.array([0, 0, 0, 0,
                 1, 1,
                 2, 2])

import matplotlib.pyplot as pyplot

%matplotlib inline

def plot_data(features,labels,axis,alpha=1.0):
    """Plot the rows of ``features`` labelled 0, 1 and 2 as green, red and blue dots.

    The points are drawn on ``axis`` with opacity ``alpha``; afterwards the
    view is fixed to [-1.5, 1.5] on both axes with equal aspect, and the axes
    are captioned 'x' and 'y'.
    """
    # position in this tuple == class label
    colours = ('green', 'red', 'blue')

    # split the features by class before drawing anything
    groups = [features[labels==idx] for idx in range(len(colours))]

    for shade, pts in zip(colours, groups):
        axis.plot(pts[:,0], pts[:,1], 'o', color=shade, markersize=12, alpha=alpha)

    # viewport
    axis.set_xlim(-1.5,1.5)
    axis.set_ylim(-1.5,1.5)
    axis.set_aspect('equal')

    # captions
    axis.set_xlabel('x')
    axis.set_ylabel('y')

# Show the raw toy samples on a fresh single-panel figure
figure, axis = pyplot.subplots(nrows=1, ncols=1)
plot_data(x, y, axis)
axis.set_title('Toy data set')
pyplot.show()

def make_covariance_ellipse(covariance, mean=None, std=2):
    """Build an ellipse patch whose axes follow a 2x2 covariance matrix.

    Parameters
    ----------
    covariance : symmetric 2x2 array-like; its eigen-decomposition gives the
        direction and length of the ellipse axes.
    mean : optional 2-element centre of the ellipse. When omitted the ellipse
        is centred at (0,0).
    std : factor applied to the square root of each eigenvalue to obtain the
        FULL axis length handed to the patch. With the default of 2 each
        semi-axis equals sqrt(eigenvalue).

    Returns
    -------
    A matplotlib.patches.Ellipse (orange, alpha 0.3). It is not added to any
    axes here; the caller does that.
    """
    import matplotlib.patches as patches
    import scipy.linalg       as linalg

    if mean is None:
        # the ellipse is centered at (0,0) unless the caller says otherwise
        mean = numpy.array([0,0])

    # eigh returns the eigenvalues in ascending order and stores the matching
    # eigenvectors as COLUMNS of v
    w,v = linalg.eigh(covariance)
    # the patch width is derived from w[0], so it must be oriented along the
    # eigenvector of w[0], i.e. column 0 (not row 0)
    u = v[:,0]
    # angle in degrees; arctan2 avoids dividing by zero for a vertical axis
    angle = numpy.degrees(numpy.arctan2(u[1], u[0]))
    # angle is passed by keyword because recent matplotlib releases no longer
    # accept it positionally; adding 180 degrees leaves the ellipse unchanged
    ellipse = patches.Ellipse(mean, std*w[0]**0.5, std*w[1]**0.5, angle=180+angle,
                              color='orange', alpha=0.3)

    return ellipse

# Euclidean distance: the identity matrix drawn as an ellipse over the toy data
figure, axis = pyplot.subplots(nrows=1, ncols=1)
plot_data(x, y, axis)
ellipse = make_covariance_ellipse(numpy.eye(2))
axis.add_artist(ellipse)
axis.set_title('Euclidean distance')
pyplot.show()

from modshogun import RealFeatures, MulticlassLabels

# wrap the toy data for Shogun: x is handed over transposed and the labels
# are cast to float64 first
features = RealFeatures(x.T)
labels   = MulticlassLabels(y.astype(numpy.float64))

from modshogun import LMNN

# each example gets this many target neighbours
k = 1

lmnn = LMNN(features,labels,k)
lmnn.set_maxiter(2000)
# the optimisation starts from the identity transform
init_transform = numpy.eye(2)
lmnn.train(init_transform)

# linear transform learned by LMNN
L = lmnn.get_linear_transform()
# L^T L, wrapped as a numpy matrix so that .I yields its inverse
M = numpy.matrix(numpy.dot(L.T, L))

# draw the inverse of M as an ellipse on top of the toy data
figure, axis = pyplot.subplots(nrows=1, ncols=1)
plot_data(x, y, axis)
ellipse = make_covariance_ellipse(M.I)
axis.add_artist(ellipse)
axis.set_title('LMNN distance')
pyplot.show()

# project original data using L (samples end up as the columns of lx)
lx = numpy.dot(L,x.T)

# print(...) with a single argument runs on both Python 2 and Python 3
print(L)

# represent the data in the projected space, with the original samples drawn
# faded (alpha 0.3) underneath for comparison
figure,axis = pyplot.subplots(1,1)
plot_data(lx.T,y,axis)
plot_data(x,y,axis,0.3)
ellipse = make_covariance_ellipse(numpy.eye(2))
axis.add_artist(ellipse)
axis.set_title('LMNN\'s linear transform')
pyplot.show()


[[ 0.93495512 -0.03239161]
 [-0.03207083  1.8724276 ]]