In [1]:
from __future__ import division, print_function

In [2]:
import numpy as np
import mmd

Generate fake data

As an example, let's regress 1d normals to their mean.

We're storing the data as a list of n numpy arrays, each of size n_samp x dim (with dim == 1).


In [3]:
n = 500
mean = np.random.normal(0, 10, size=n)
var = np.random.gamma(5, size=n)
n_samp = np.random.randint(10, 500, size=n)
samps = [np.random.normal(m, v, size=s)[:, np.newaxis]
         for m, v, s in zip(mean, var, n_samp)]

In [4]:
# this gives us a progress bar for MMD computations
from mmd.utils import show_progress
show_progress('mmd.mmd.progress')

In [5]:
# Get the median pairwise squared distance in the aggregate sample,
# as a heuristic for choosing the bandwidth of the inner RBF kernel.
from sklearn.metrics.pairwise import euclidean_distances
sub = np.vstack(samps)
sub = sub[np.random.choice(sub.shape[0], min(1000, sub.shape[0]), replace=False)]
D2 = euclidean_distances(sub, squared=True)
med_2 = np.median(D2[np.triu_indices_from(D2, k=1)], overwrite_input=True)
del sub, D2

In [6]:
from sklearn import cross_validation as cv
from sklearn.kernel_ridge import KernelRidge
import sys

In [7]:
l1_gamma_mults = np.array([1/16, 1/4, 1, 4])  # Could expand these, but it's quicker if you don't. :)
l1_gammas = l1_gamma_mults * med_2

Now we'll get the $\mathrm{MMD}^2$ values for each of the proposed gammas. (This is maybe somewhat faster than doing it independently, but I haven't really tested that; if it's causing memory issues or anything, do them separately.)

We also want to save the "diagonal" values (the mean map kernel between a set and itself), so we can compute the MMD to test values later without recomputing those.


In [8]:
mmds, mmk_diags = mmd.rbf_mmd(samps, gammas=l1_gammas, squared=True, n_jobs=40, ret_X_diag=True)


RBF mean map kernel:
 125,250 of 125,250 (100%) |###########################################################| Time: 0:02:11

Now, we want to turn this into a kernel and evaluate the regression for each of the other hyperparameter values.

We'll just do a 3d grid search here.

Ideally, this would be:

  • Parallelized. Scikit-learn's tools for this want to pickle the kernels, which is a non-starter here. It'd take some coding to work around that.
  • Not a grid search. It's harder to get away from grid search for the l1 gamma because it's kind of expensive, but you could definitely do a randomized search (or some kind of actual optimization) for the l2 gamma + alpha parameters.
  • Really, KernelRidge should support leave-one-out CV across a bunch of alphas, like RidgeCV. Supposedly this isn't too hard but I haven't spent the time to try to figure it out, and apparently neither has anyone else. This would help with the parallelization issue.

In [9]:
# Choose parameters for the hyperparameter search
k_fold = list(cv.KFold(n, n_folds=3, shuffle=True))

l2_gamma_mults = np.array([1/4, 1, 4, 8])
alphas = np.array([1/128, 1/64, 1/16, 1/4, 1, 4])

scores = np.empty((l1_gamma_mults.size, l2_gamma_mults.size, alphas.size, len(k_fold)))
scores.fill(np.nan)

In [10]:
%%time

K = np.empty((n, n), dtype=samps[0].dtype)
for l1_gamma_i, l1_gamma in enumerate(l1_gamma_mults * med_2):
    print("l1 gamma {} / {}: {:.4}".format(l1_gamma_i + 1, len(l1_gamma_mults), l1_gamma), file=sys.stderr)
    D2_mmd = mmds[l1_gamma_i]
    
    # get the median of *these* squared distances,
    # to scale the bandwidth of the outer RBF kernel
    mmd_med2 = np.median(D2_mmd[np.triu_indices_from(D2_mmd, k=1)])
    
    for l2_gamma_i, l2_gamma in enumerate(l2_gamma_mults * mmd_med2):
        print("\tl2 gamma {} / {}: {:.4}".format(l2_gamma_i + 1, len(l2_gamma_mults), l2_gamma), file=sys.stderr)
        np.multiply(D2_mmd, -l2_gamma, out=K)
        np.exp(K, out=K)
        
        for alpha_i, alpha in enumerate(alphas):
            ridge = KernelRidge(alpha=alpha, kernel='precomputed')
            these = cv.cross_val_score(ridge, K, mean, cv=k_fold)
            scores[l1_gamma_i, l2_gamma_i, alpha_i, :] = these
            print("\t\talpha {} / {}: {} \t {}".format(alpha_i + 1, len(alphas), alpha, these), file=sys.stderr)


l1 gamma 1 / 4: 7.562
	l2 gamma 1 / 4: 0.01677
		alpha 1 / 6: 0.0078125 	 [ 0.9253348   0.91112741  0.92746963]
		alpha 2 / 6: 0.015625 	 [ 0.86428587  0.83874947  0.87861644]
		alpha 3 / 6: 0.0625 	 [ 0.64522929  0.60369467  0.67002685]
		alpha 4 / 6: 0.25 	 [ 0.33072326  0.29631131  0.34359958]
		alpha 5 / 6: 1.0 	 [ 0.10770497  0.08482278  0.11357141]
		alpha 6 / 6: 4.0 	 [ 0.02324835  0.005496    0.02778346]
	l2 gamma 2 / 4: 0.06708
		alpha 1 / 6: 0.0078125 	 [ 0.98203002  0.9805339   0.96849574]
		alpha 2 / 6: 0.015625 	 [ 0.96198555  0.95618912  0.95410813]
		alpha 3 / 6: 0.0625 	 [ 0.86382546  0.83823775  0.87833251]
		alpha 4 / 6: 0.25 	 [ 0.64438717  0.60291489  0.66924618]
		alpha 5 / 6: 1.0 	 [ 0.32986607  0.29558004  0.34272459]
		alpha 6 / 6: 4.0 	 [ 0.10734844  0.08462546  0.11317859]
	l2 gamma 3 / 4: 0.2683
		alpha 1 / 6: 0.0078125 	 [ 0.99543429  0.99579355  0.9813862 ]
		alpha 2 / 6: 0.015625 	 [ 0.99145241  0.99138094  0.97655579]
		alpha 3 / 6: 0.0625 	 [ 0.9612365   0.95532443  0.95390085]
		alpha 4 / 6: 0.25 	 [ 0.86194888  0.8361614   0.87716608]
		alpha 5 / 6: 1.0 	 [ 0.64100342  0.59978763  0.66610057]
		alpha 6 / 6: 4.0 	 [ 0.32645335  0.29266772  0.33923855]
	l2 gamma 4 / 4: 0.5366
		alpha 1 / 6: 0.0078125 	 [ 0.99660695  0.9970905   0.98441303]
		alpha 2 / 6: 0.015625 	 [ 0.9952733   0.99556481  0.98137235]
		alpha 3 / 6: 0.0625 	 [ 0.9809779   0.97938433  0.96832703]
		alpha 4 / 6: 0.25 	 [ 0.92204166  0.90731807  0.9260111 ]
		alpha 5 / 6: 1.0 	 [ 0.76536657  0.7292215   0.78998224]
		alpha 6 / 6: 4.0 	 [ 0.48016052  0.44116842  0.4998896 ]
l1 gamma 2 / 4: 30.25
	l2 gamma 1 / 4: 0.01013
		alpha 1 / 6: 0.0078125 	 [ 0.79933156  0.76548461  0.8205987 ]
		alpha 2 / 6: 0.015625 	 [ 0.68228314  0.64153397  0.70724916]
		alpha 3 / 6: 0.0625 	 [ 0.37242102  0.33629039  0.38694699]
		alpha 4 / 6: 0.25 	 [ 0.1284564   0.10438441  0.13474427]
		alpha 5 / 6: 1.0 	 [ 0.02964356  0.01138949  0.03426635]
		alpha 6 / 6: 4.0 	 [ 0.001045   -0.01533853  0.00537696]
	l2 gamma 2 / 4: 0.04052
		alpha 1 / 6: 0.0078125 	 [ 0.93532218  0.92254241  0.93446721]
		alpha 2 / 6: 0.015625 	 [ 0.88182687  0.85869737  0.89289967]
		alpha 3 / 6: 0.0625 	 [ 0.68199126  0.64126241  0.70698824]
		alpha 4 / 6: 0.25 	 [ 0.37209429  0.33600907  0.38661566]
		alpha 5 / 6: 1.0 	 [ 0.1283001   0.10427815  0.1345786 ]
		alpha 6 / 6: 4.0 	 [ 0.02963815  0.01150946  0.03423441]
	l2 gamma 3 / 4: 0.1621
		alpha 1 / 6: 0.0078125 	 [ 0.98325221  0.98147808  0.96896823]
		alpha 2 / 6: 0.015625 	 [ 0.96648322  0.96092088  0.95684321]
		alpha 3 / 6: 0.0625 	 [ 0.88119356  0.85801448  0.89254499]
		alpha 4 / 6: 0.25 	 [ 0.68082137  0.64017482  0.70594108]
		alpha 5 / 6: 1.0 	 [ 0.370789   0.3348852  0.3852916]
		alpha 6 / 6: 4.0 	 [ 0.1276767   0.10385361  0.13391809]
	l2 gamma 4 / 4: 0.3241
		alpha 1 / 6: 0.0078125 	 [ 0.99126925  0.99096328  0.9758109 ]
		alpha 2 / 6: 0.015625 	 [ 0.98309079  0.98126885  0.96893928]
		alpha 3 / 6: 0.0625 	 [ 0.93434939  0.92145087  0.93408128]
		alpha 4 / 6: 0.25 	 [ 0.79701186  0.76315473  0.8188684 ]
		alpha 5 / 6: 1.0 	 [ 0.52994289  0.48929257  0.55134852]
		alpha 6 / 6: 4.0 	 [ 0.22805858  0.19894332  0.23710312]
l1 gamma 3 / 4: 121.0
	l2 gamma 1 / 4: 0.006552
		alpha 1 / 6: 0.0078125 	 [ 0.59008066  0.54844952  0.61324882]
		alpha 2 / 6: 0.015625 	 [ 0.43084104  0.39261253  0.44779727]
		alpha 3 / 6: 0.0625 	 [ 0.16134347  0.13547177  0.16838393]
		alpha 4 / 6: 0.25 	 [ 0.04036663  0.02144411  0.04510735]
		alpha 5 / 6: 1.0 	 [ 0.0038847  -0.01279122  0.00825952]
		alpha 6 / 6: 4.0 	 [-0.00566876 -0.02163579 -0.00138511]
	l2 gamma 2 / 4: 0.02621
		alpha 1 / 6: 0.0078125 	 [ 0.83080074  0.79987327  0.84862107]
		alpha 2 / 6: 0.015625 	 [ 0.72747294  0.68839863  0.75177899]
		alpha 3 / 6: 0.0625 	 [ 0.4306952   0.39248634  0.44765038]
		alpha 4 / 6: 0.25 	 [ 0.16126049  0.13540816  0.16829858]
		alpha 5 / 6: 1.0 	 [ 0.04035006  0.02146358  0.04508471]
		alpha 6 / 6: 4.0 	 [ 0.00392248 -0.01263198  0.008274  ]
	l2 gamma 3 / 4: 0.1048
		alpha 1 / 6: 0.0078125 	 [ 0.94530131  0.93322669  0.94071526]
		alpha 2 / 6: 0.015625 	 [ 0.90111136  0.88029346  0.9076435 ]
		alpha 3 / 6: 0.0625 	 [ 0.72700472  0.6879654   0.75138133]
		alpha 4 / 6: 0.25 	 [ 0.43011198  0.39198175  0.44706291]
		alpha 5 / 6: 1.0 	 [ 0.16092904  0.13515403  0.16795764]
		alpha 6 / 6: 4.0 	 [ 0.04028344  0.02153968  0.04499407]
	l2 gamma 4 / 4: 0.2096
		alpha 1 / 6: 0.0078125 	 [ 0.97034964  0.96425102  0.95838798]
		alpha 2 / 6: 0.015625 	 [ 0.94508432  0.93298604  0.94062706]
		alpha 3 / 6: 0.0625 	 [ 0.82997094  0.79905935  0.84805169]
		alpha 4 / 6: 0.25 	 [ 0.5886217   0.54715845  0.61184844]
		alpha 5 / 6: 1.0 	 [ 0.27767896  0.24596812  0.28839051]
		alpha 6 / 6: 4.0 	 [ 0.08427846  0.06296932  0.08963635]
l1 gamma 4 / 4: 484.0
	l2 gamma 1 / 4: 0.004618
		alpha 1 / 6: 0.0078125 	 [ 0.35031579  0.31512126  0.36359181]
		alpha 2 / 6: 0.015625 	 [ 0.21405404  0.18534118  0.22243785]
		alpha 3 / 6: 0.0625 	 [ 0.0590918   0.03903424  0.06403364]
		alpha 4 / 6: 0.25 	 [ 0.00903746 -0.00798812  0.01344722]
		alpha 5 / 6: 1.0 	 [ -4.38927273e-03  -2.05560712e-02  -7.67348200e-05]
		alpha 6 / 6: 4.0 	 [-0.0077666  -0.02360415 -0.0034969 ]
	l2 gamma 2 / 4: 0.01847
		alpha 1 / 6: 0.0078125 	 [ 0.65892488  0.61743994  0.68291129]
		alpha 2 / 6: 0.015625 	 [ 0.50890343  0.4683392   0.52877418]
		alpha 3 / 6: 0.0625 	 [ 0.21399975  0.18529711  0.22238288]
		alpha 4 / 6: 0.25 	 [ 0.05907455  0.03902952  0.06401532]
		alpha 5 / 6: 1.0 	 [ 0.00904342 -0.00794874  0.01344782]
		alpha 6 / 6: 4.0 	 [ -4.34596401e-03  -2.03923785e-02  -5.60242740e-05]
	l2 gamma 3 / 4: 0.07389
		alpha 1 / 6: 0.0078125 	 [ 0.86410064  0.83618152  0.87608821]
		alpha 2 / 6: 0.015625 	 [ 0.77854721  0.74208114  0.79987677]
		alpha 3 / 6: 0.0625 	 [ 0.50859409  0.46807412  0.52846942]
		alpha 4 / 6: 0.25 	 [ 0.21378271  0.18512098  0.22216313]
		alpha 5 / 6: 1.0 	 [ 0.05900562  0.03901061  0.0639421 ]
		alpha 6 / 6: 4.0 	 [ 0.00906677 -0.00779334  0.01344993]
	l2 gamma 4 / 4: 0.1478
		alpha 1 / 6: 0.0078125 	 [ 0.91993713  0.90062585  0.92032128]
		alpha 2 / 6: 0.015625 	 [ 0.86387362  0.83596879  0.87594703]
		alpha 3 / 6: 0.0625 	 [ 0.6582745  0.6168746  0.6823209]
		alpha 4 / 6: 0.25 	 [ 0.34957302  0.31449283  0.36284128]
		alpha 5 / 6: 1.0 	 [ 0.1174891   0.09412142  0.1234212 ]
		alpha 6 / 6: 4.0 	 [ 0.02630513  0.00841186  0.03084043]
CPU times: user 28.1 s, sys: 22.9 s, total: 51 s
Wall time: 3.79 s


In [11]:
mean_scores = scores.mean(axis=-1)

In [12]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

In [13]:
import matplotlib.ticker as ticker

fig = plt.figure(figsize=(8, 15))
for i in range(len(l1_gamma_mults)):
    ax = fig.add_subplot(len(l1_gamma_mults), 1, i+1)
    cax = ax.matshow(mean_scores[i, :, :], vmin=0, vmax=1)
    
    ax.set_yticklabels([''] + ['{:.3}'.format(g * mmd_med2) for g in l2_gamma_mults])
    ax.set_xticklabels([''] + ['{:.3}'.format(a) for a in alphas])
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    
    fig.colorbar(cax)
    ax.set_title("L1 gamma: {}".format(l1_gamma_mults[i] * med_2))
plt.tight_layout()


Looking at these results, the best results are with the lowest alpha, lowest L1 gamma, and highest L2 gamma. So if we were really trying to solve this as best as possible, we'd maybe want to try lower alphas / lower L1 gammas / higher L2 gammas. But, whatever; let's just use the best of the ones we did try for now.


In [14]:
best_l1_gamma_i, best_l2_gamma_i, best_alpha_i = np.unravel_index(mean_scores.argmax(), mean_scores.shape)
best_l1_gamma = l1_gamma_mults[best_l1_gamma_i] * med_2
best_l2_gamma = l2_gamma_mults[best_l2_gamma_i] * mmd_med2
best_alpha = alphas[best_alpha_i]

In [16]:
best_l1_gamma_i, best_l2_gamma_i, best_alpha_i


Out[16]:
(0, 3, 0)

In [15]:
best_l1_gamma, best_l2_gamma, best_alpha


Out[15]:
(7.5621827082571507, 0.14778624920765476, 0.0078125)

Now, train a model on the full training set:


In [17]:
# get the training kernel
D2_mmd = mmds[best_l1_gamma_i]
np.multiply(D2_mmd, -best_l2_gamma, out=K)
np.exp(K, out=K)

ridge = KernelRidge(alpha=best_alpha, kernel='precomputed')
ridge.fit(K, mean)


Out[17]:
KernelRidge(alpha=0.0078125, coef0=1, degree=3, gamma=None,
      kernel='precomputed', kernel_params=None)

To evaluate on new data:


In [18]:
# generate some test data from the same distribution
t_n = 100
t_mean = np.random.normal(0, 10, size=t_n)
t_var = np.random.gamma(5, size=t_n)
t_n_samp = np.random.randint(10, 500, size=t_n)
t_samps = [np.random.normal(m, v, size=s)[:, np.newaxis]
           for m, v, s in zip(t_mean, t_var, t_n_samp)]

In [19]:
# get the kernel from the training data to the test data
t_K = mmd.rbf_mmd(t_samps, samps, gammas=best_l1_gamma, squared=True,
                  Y_diag=mmk_diags[best_l1_gamma_i], n_jobs=20)
t_K *= -best_l2_gamma
np.exp(t_K, out=t_K);


RBF mean map kernel:
 50,100 of 50,100 (100%) |#############################################################| Time: 0:00:36


In [20]:
preds = ridge.predict(t_K)

In [21]:
plt.figure(figsize=(5, 5))
plt.scatter(t_mean, preds)


Out[21]:
<matplotlib.collections.PathCollection at 0x7fab400e2710>

Pretty good predictions!