A notebook to test the quadratic-time MMD two-sample test.


In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import freqopttest.util as util
import freqopttest.data as data
import freqopttest.kernel as kernel
import freqopttest.tst as tst
import freqopttest.glo as glo
import sys
import time

In [ ]:
# Sample source and experiment configuration.
n = 800        # sample size passed to ss.sample() below
dim = 1        # dimension handed to the sample-source constructor
seed = 14      # seed for drawing the sample
alpha = 0.01   # alpha passed to grid_search_kernel and QuadMMDTest in later cells

# Choose exactly ONE sample source. Previously SSGaussMeanDiff was constructed
# and then immediately overwritten by SSGaussVarDiff; the dead assignment is now
# commented out like the other alternatives, so the effective source is unchanged.
#ss = data.SSGaussMeanDiff(dim, my=1)
ss = data.SSGaussVarDiff(dim)
#ss = data.SSSameGauss(dim)
#ss = data.SSBlobs()

# Re-read the dimension from the source itself.
dim = ss.dim()
tst_data = ss.sample(n, seed=seed)
# Split into a training part (tr) and a test part (te), half each.
tr, te = tst_data.split_tr_te(tr_proportion=0.5, seed=10)

Test permutations


In [ ]:
xtr, ytr = tr.xy()
xytr = tr.stack_xy()
# Median heuristic for the Gaussian kernel width.
# NOTE(review): the grid-search cell further down squares the value returned by
# util.meddistance before using it as a Gaussian width, which indicates that
# meddistance returns an (unsquared) distance. Square it here as well so that
# `sig2` really is a squared width -- confirm against util.meddistance.
sig2 = util.meddistance(xytr, subsample=1000)**2
k = kernel.KGauss(sig2)

In [ ]:
# Time QuadMMDTest.permutation_list_mmd2 on the training split with kernel k
# and 200 permutations. The result, perm_mmds1, is reused by the histogram
# cell later in the notebook.
start = time.time()

perm_mmds1 = tst.QuadMMDTest.permutation_list_mmd2(xtr, ytr, k, n_permute=200)

end = time.time()
# wall-clock time in seconds
print('permutations took: %.4f s'%(end-start))
print('perm_mmds1', perm_mmds1)

In [ ]:
def chi_square_weights_H0(k, X):
    """
    Return a numpy array of the weights to be used as the weights in the
    weighted sum of chi-squares for the null distribution of MMD^2.

    The weights are the eigenvalues of the centred Gram matrix HKH, divided by
    n^2, sorted in decreasing order.

    - k: a Kernel. Only k.eval(X, X) is used; it must return the n x n Gram
      matrix, which is assumed to be symmetric.
    - X: n x d numpy array of n data points

    Return a real 1d numpy array of length n.
    """
    n = X.shape[0]
    # Gram matrix
    K = k.eval(X, X)
    # Centre the Gram matrix: HKH with H = I - 11^T/n. Subtracting row/column
    # means and adding back the grand mean gives the same matrix without
    # forming H or doing two n x n matrix products.
    row_means = K.mean(axis=1, keepdims=True)
    col_means = K.mean(axis=0, keepdims=True)
    HKH = K - row_means - col_means + K.mean()
    # HKH is symmetric, so use the symmetric eigen-solver: it returns real
    # eigenvalues (no spurious imaginary parts to discard) in ascending order.
    evals = np.linalg.eigvalsh(HKH)
    # sort in decreasing order
    evals = evals[::-1]
    weights = evals/float(n)**2
    return weights

In [ ]:
def simulate_null_spectral(weights, n_simulate=1000, seed=275):
    """
    Simulate values of MMD^2 (NOT n*MMD^2) from the null distribution with the
    spectral method.

    - weights: chi-square weights (a finite stand-in for the infinite weighted
      sum of chi-squares). Must support len() and .dot().
    - n_simulate: number of values to simulate
    - seed: seed handed to util.NumpySeedContext around the random draws

    Return a numpy array of length n_simulate.
    """
    # never draw more than this many values in one batch
    batch_cap = 400
    n_weights = len(weights)
    simulated = np.zeros(n_simulate)

    with util.NumpySeedContext(seed=seed):
        for lo in range(0, n_simulate, batch_cap):
            hi = min(lo + batch_cap, n_simulate)
            # squared standard normals: one column per simulated value
            sq_normals = np.random.randn(n_weights, hi - lo)**2
            # weighted sum of the centred draws, one entry per column
            simulated[lo:hi] = 2.0*weights.dot(sq_normals - 1.0)
    return simulated

In [ ]:
# Compare the permutation null sample with the spectral simulation of the null.
xytr = np.vstack((xtr, ytr))
chi2_weights = chi_square_weights_H0(k, xytr)
sim_mmds = simulate_null_spectral(chi2_weights, n_simulate=2000)
# transparency shared by both histograms so the overlap stays visible
a = 0.6

plt.figure(figsize=(8, 5))
# `normed` was removed from hist() in Matplotlib 3.1; `density=True` is the
# replacement and normalises each histogram to integrate to 1.
plt.hist(perm_mmds1, 20, color='blue', density=True, label='Permutation', alpha=a)
plt.hist(sim_mmds, 20, color='red', density=True, label='Spectral', alpha=a)
plt.xlabel('MMD^2')
plt.ylabel('density')
plt.legend();

In [ ]:
# test h1_mean_var
# Time QuadMMDTest.h1_mean_var on the training split with kernel k, asking for
# the variance as well (is_var_computed=True).
# NOTE(review): presumably the mean and variance of the MMD^2 estimate under
# H1 -- confirm against tst.QuadMMDTest.h1_mean_var.
start = time.time()
mean, var = tst.QuadMMDTest.h1_mean_var(xtr, ytr, k, is_var_computed=True)
end = time.time()
# wall-clock time in seconds
print('h1_mean_var took: %.3f'%(end - start))
print('mean: %.3g, var: %.3g'%(mean, var))

In [ ]:
# test h1_mean_var_gram
# Same quantities as the h1_mean_var cell, but computed from precomputed Gram
# matrices. The timing deliberately includes building the three Gram matrices.
start = time.time()
Kx = k.eval(xtr, xtr)
Ky = k.eval(ytr, ytr)
Kxy = k.eval(xtr, ytr)
mean, var = tst.QuadMMDTest.h1_mean_var_gram(Kx, Ky, Kxy, k, True)
end = time.time()
# The label used to say 'h1_mean_var' (copied from the previous cell), which
# made the two timings indistinguishable in the output.
print('h1_mean_var_gram took: %.3f'%(end - start))
print('mean: %.3g, var: %.3g'%(mean, var))

In [ ]:
# choose the best parameter and perform a test with permutations
med = util.meddistance(tr.stack_xy(), 1000)
# Candidate squared Gaussian widths: med^2 scaled by 20 factors spaced evenly
# (in the exponent) between 2^-4 and 2^4. The former np.hstack wrapper received
# a single array rather than a tuple of arrays, so it was a no-op and is dropped.
list_gwidth = (med**2)*(2.0**np.linspace(-4, 4, 20))
list_gwidth.sort()
list_kernels = [kernel.KGauss(gw2) for gw2 in list_gwidth]

# grid search to choose the best Gaussian width, using the training split only
besti, powers = tst.QuadMMDTest.grid_search_kernel(tr, list_kernels, alpha)
# kernel to use when performing the test
best_ker = list_kernels[besti]

In [ ]:
# Build the quadratic MMD test with the kernel picked by the grid search
# (200 permutations, alpha as set in the configuration cell) and run it on the
# held-out test split `te`, timing construction and testing together.
start = time.time()
mmd_test = tst.QuadMMDTest(best_ker, n_permute=200, alpha=alpha)
test_result = mmd_test.perform_test(te)
end = time.time()
print('MMD test took: %s seconds'%(end-start))
# bare last expression so the notebook displays the result
test_result

New permutation


In [ ]:
# # @staticmethod
# def permutation_list_mmd2_rahul(X, Y, k, n_permute=400, seed=8273):
#     """ Permutation by maintaining inverse indices. This approach is due to
#     Rahul (Soumyajit De) briefly described in "Generative Models and Model
#     Criticism via Optimized Maximum Mean Discrepancy" 
    
#     X: m x d matrix
#     Y: n x d matrix of data
#     k: a Kernel
#     """
    
#     def which_term(I, J, m):
#         """
#         There are three terms in MMD computation.
#         MMD^2 = \sum_{i=1}^m \sum_{j \neq i} k(x_i, x_j)/(m*(m-1))
#             + \sum_{i=1}^n \sum_{j \neq i} k(y_i, y_j)/(n*(n-1))
#             - 2*\sum_{i=1}^m \sum_{j=1}^n k(x_i, y_j)/(m*n)
            
#             Return 0 if (i,j) should participate in the first term
#             1 for the second term
#             2 for the third term
            
#         - I, J 1d arrays of indices.
#         """
#         assert len(I)==len(J)
#         t1 = np.logical_and(I<m , J<m)
#         t2 = np.logical_and(I>=m, J>=m)
# #         t3 = np.logical_not(np.logical_or(t1, t2))
#         t3 = np.logical_or(np.logical_and(I>=m, J<m), np.logical_and(I<m, J>=m))
#         term_inds = -np.ones(len(I))
#         term_inds = term_inds.astype(int)
#         term_inds[t1] = 0
#         term_inds[t2] = 1
#         term_inds[t3] = 2
# #         assert np.all(term_inds >= 0)
#         return term_inds
        
        
#     XY = np.vstack((X, Y))
#     m = X.shape[0]
#     n = Y.shape[0]
#     KZ = k.eval(XY, XY)
#     terms = np.zeros((n_permute, 3))
#     # a matrix of indices of size n_permute x (m+n)
#     perm_inds = np.zeros((n_permute, m+n))
#     R = range(n_permute)
#     with util.NumpySeedContext(seed=seed):
#         for i in range(n_permute):
#             perm_inds[i] = np.random.permutation(m+n)
#     for i in range(m+n):
#         for j in range(m+n):
#             if i!=j:
#                 # ij is a number array of length 2
#                 IJ = perm_inds[:, [i, j]]
#                 term_inds = which_term(IJ[:, 0], IJ[:, 1], m)
#                 # !!! I hope the sum will not cause an overflow..
#                 terms[R, term_inds] =terms[R, term_inds] + KZ[i, j]
                
#     mmds = terms[:, 0]/float(m*(m-1)) + terms[:, 1]/float(n*(n-1)) - terms[:, 2]*2.0/(m*n)
#     return mmds, terms

In [ ]:
# xte, yte = te.xy()
# perm_mmds2, terms = permutation_list_mmd2_rahul(xte, yte, best_ker, n_permute=30, seed=399)

In [ ]:
# terms

In [ ]:
# perm_mmds = test_result['list_permuted_mmd2']
# plt.hist(perm_mmds, alpha=0.5, label='Current', bins=20)
# plt.hist(perm_mmds2, alpha=0.5, label='New version', bins=20)
# plt.legend(loc='best')

In [ ]:


In [ ]: