A notebook to test the quadratic MMD two-sample test (`tst.QuadMMDTest`).



In [ ]:

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import freqopttest.util as util
import freqopttest.data as data
import freqopttest.kernel as kernel
import freqopttest.tst as tst
import freqopttest.glo as glo
import sys
import time




In [ ]:

# Sample source: keep exactly ONE `ss = ...` line active; the others are
# alternatives (uncomment to switch problems).
n = 800
dim = 1
seed = 14
# significance level used by the tests below
alpha = 0.01
#ss = data.SSGaussMeanDiff(dim, my=1)
ss = data.SSGaussVarDiff(dim)
#ss = data.SSSameGauss(dim)
#ss = data.SSBlobs()
# SSBlobs() takes no dim argument, so always read the dimension back from
# the chosen source.
dim = ss.dim()
tst_data = ss.sample(n, seed=seed)
# 50/50 split: `tr` is used for kernel selection, `te` for the final test.
tr, te = tst_data.split_tr_te(tr_proportion=0.5, seed=10)



## Test permutations



In [ ]:

# Gaussian kernel for the permutation-vs-spectral comparison, with its width
# chosen by the median heuristic on the pooled training sample.
xtr, ytr = tr.xy()
xytr = tr.stack_xy()
# meddistance returns the median pairwise distance, while KGauss is
# parameterised by a squared width (the grid-search cell below likewise
# passes med**2 to KGauss), hence the square.
sig2 = util.meddistance(xytr, subsample=1000)**2
k = kernel.KGauss(sig2)




In [ ]:

# Draw 200 permuted MMD^2 values on the training split and report how long
# the permutation procedure takes.
start = time.time()
perm_mmds1 = tst.QuadMMDTest.permutation_list_mmd2(
    xtr, ytr, k, n_permute=200)
end = time.time()

print('permutations took: %.4f s' % (end - start))
print('perm_mmds1', perm_mmds1)




In [ ]:

def chi_square_weights_H0(k, X):
    """
    Return a numpy array of the weights to be used in the weighted sum of
    chi-squares that approximates the null distribution of MMD^2.

    The weights are the eigenvalues of the centred Gram matrix H K H
    (H = I - 11^T/n), sorted in decreasing order and divided by n^2.

    - k: a Kernel; only k.eval(X, X) is used, and it must return the
      n x n Gram matrix of X.
    - X: n x d numpy array of n data points.

    Return a 1-D real numpy array of length n (empty when n == 0).
    """
    n = X.shape[0]
    if n == 0:
        return np.zeros(0)
    # Gram matrix
    K = k.eval(X, X)
    # Centre the Gram matrix. H K H equals K minus its column means and its
    # row means plus its grand mean; this costs O(n^2) instead of the O(n^3)
    # of two dense products with H.
    HKH = (K
           - K.mean(axis=0)[np.newaxis, :]
           - K.mean(axis=1)[:, np.newaxis]
           + K.mean())
    # H K H is symmetric in exact arithmetic; remove rounding asymmetry so
    # that eigvalsh (which reads only one triangle) is appropriate.
    HKH = 0.5*(HKH + HKH.T)
    # eigvalsh is the symmetric eigen-solver: real eigenvalues in ascending
    # order, with no complex round-off to strip off.
    evals = np.linalg.eigvalsh(HKH)
    # reverse to get decreasing order
    evals = evals[::-1]
    weights = evals/float(n)**2
    return weights




In [ ]:

def simulate_null_spectral(weights, n_simulate=1000, seed=275, block_size=400):
    """
    Return the values of MMD^2 (NOT n*MMD^2) simulated from the null
    distribution by the spectral method. Each simulated value is
    2 * sum_l weights[l] * (z_l^2 - 1) with z_l i.i.d. standard normal.

    - weights: 1-D array-like of chi-square weights (for the infinite
      weighted sum of chi-squares).
    - n_simulate: number of null values to draw.
    - seed: random seed. A private RandomState is used, so the global numpy
      random state is not modified.
    - block_size: maximum number of values drawn per batch. This bounds the
      temporary len(weights) x block_size matrix of normal draws.

    Return a numpy array of length n_simulate.
    Raise ValueError if block_size < 1.
    """
    if block_size < 1:
        raise ValueError('block_size must be a positive integer, got %r' % (block_size,))
    weights = np.asarray(weights, dtype=float)
    D = len(weights)
    mmds = np.zeros(n_simulate)
    # RandomState(seed) yields the same stream as np.random.seed(seed)
    # followed by np.random.randn, without touching the global generator.
    rng = np.random.RandomState(seed)
    for from_ind in range(0, n_simulate, block_size):
        to_draw = min(block_size, n_simulate - from_ind)
        # D x to_draw chi^2(1) random variables
        chi2 = rng.randn(D, to_draw)**2
        # one simulated MMD^2 per column
        mmds[from_ind:from_ind + to_draw] = 2.0*weights.dot(chi2 - 1.0)
    return mmds




In [ ]:

# Compare the permutation null distribution of MMD^2 with the one simulated
# by the spectral method (both computed on the training split).
xytr = np.vstack((xtr, ytr))
chi2_weights = chi_square_weights_H0(k, xytr)
sim_mmds = simulate_null_spectral(chi2_weights, n_simulate=2000)
# histogram transparency
a = 0.6

plt.figure(figsize=(8, 5))
# `normed` was removed from hist() in matplotlib 3.1; `density` replaces it.
plt.hist(perm_mmds1, 20, color='blue', density=True, label='Permutation', alpha=a)
plt.hist(sim_mmds, 20, color='red', density=True, label='Spectral', alpha=a)
plt.xlabel('MMD$^2$')
plt.ylabel('Density')
plt.title('Null distribution of MMD$^2$: permutation vs spectral')
plt.legend();




In [ ]:

# Time QuadMMDTest.h1_mean_var on the training split (data-matrix
# interface), requesting the variance as well as the mean.
start = time.time()
mean, var = tst.QuadMMDTest.h1_mean_var(
    xtr, ytr, k, is_var_computed=True)
end = time.time()

print('h1_mean_var took: %.3f' % (end - start))
print('mean: %.3g, var: %.3g' % (mean, var))




In [ ]:

# Time QuadMMDTest.h1_mean_var_gram (Gram-matrix interface). The timing
# includes building the three Gram matrices.
start = time.time()
Kx = k.eval(xtr, xtr)
Ky = k.eval(ytr, ytr)
Kxy = k.eval(xtr, ytr)
# NOTE(review): `k` and `True` are passed positionally. Confirm they line up
# with h1_mean_var_gram's signature in freqopttest.tst, since a Gram-matrix
# interface may not take a kernel at all.
mean, var = tst.QuadMMDTest.h1_mean_var_gram(Kx, Ky, Kxy, k, True)
end = time.time()
print('h1_mean_var_gram took: %.3f'%(end - start))
print('mean: %.3g, var: %.3g'%(mean, var))




In [ ]:

# Choose the best Gaussian width on the training split by grid search, then
# build the quadratic MMD test with it (the test is run on `te` in the next
# cell).
med = util.meddistance(tr.stack_xy(), subsample=1000)
# 20 candidate squared widths, log-spaced from med^2/16 to 16*med^2.
# The array is already in increasing order.
list_gwidth = (med**2)*(2.0**np.linspace(-4, 4, 20))
list_kernels = [kernel.KGauss(gw2) for gw2 in list_gwidth]

# grid search to choose the best Gaussian width
besti, powers = tst.QuadMMDTest.grid_search_kernel(tr, list_kernels, alpha)
best_ker = list_kernels[besti]
# The next cell calls mmd_test.perform_test(te), but mmd_test was not defined
# anywhere else in this notebook, so construct it here from the chosen kernel.
# NOTE(review): assumes the constructor QuadMMDTest(kernel, n_permute=..., alpha=...);
# confirm against freqopttest.tst.
mmd_test = tst.QuadMMDTest(best_ker, n_permute=400, alpha=alpha)




In [ ]:

# Run the selected quadratic MMD test on the held-out split, time it, and
# display the result.
start = time.time()
test_result = mmd_test.perform_test(te)
end = time.time()

print('MMD test took: %s seconds' % (end - start))
test_result



## New permutation scheme (inverse indices; experimental, currently commented out)



In [ ]:

# NOTE(review): this experimental implementation is disabled (all lines are
# commented out). If it is re-enabled, fix the final line first: the statistic
# is computed into `n_mmds`, but the function returns the undefined name `mmds`
# (NameError). The pairwise loop performs (m+n)*(m+n-1) Python-level iterations,
# each touching n_permute rows, so it is slow for large samples.
# # @staticmethod
# def permutation_list_mmd2_rahul(X, Y, k, n_permute=400, seed=8273):
#     """ Permutation by maintaining inverse indices. This approach is due to
#     Rahul (Soumyajit De) briefly described in "Generative Models and Model
#     Criticism via Optimized Maximum Mean Discrepancy"

#     X: m x d matrix
#     Y: n x d matrix of data
#     k: a Kernel
#     """

#     def which_term(I, J, m):
#         """
#         There are three terms in MMD computation.
#         MMD^2 = \sum_{i=1}^m \sum_{j \neq i} k(x_i, x_j)/(m*(m-1))
#             + \sum_{i=1}^n \sum_{j \neq i} k(y_i, y_j)/(n*(n-1))
#             -2*\sum_{i=1}^m \sum_{j=1}^n k(x_i, y_j)/(m*n)

#             Return 0 if (i,j) should participate in the first term
#             1 for the second term
#             2 for the third term

#         - I, J 1d arrays of indices.
#         """
#         assert len(I)==len(J)
#         t1 = np.logical_and(I<m , J<m)
#         t2 = np.logical_and(I>=m, J>=m)
# #         t3 = np.logical_not(np.logical_or(t1, t2))
#         t3 = np.logical_or(np.logical_and(I>=m, J<m), np.logical_and(I<m, J>=m))
#         term_inds = -np.ones(len(I))
#         term_inds = term_inds.astype(int)
#         term_inds[t1] = 0
#         term_inds[t2] = 1
#         term_inds[t3] = 2
# #         assert np.all(term_inds >= 0)
#         return term_inds

#     XY = np.vstack((X, Y))
#     m = X.shape[0]
#     n = Y.shape[0]
#     KZ = k.eval(XY, XY)
#     terms = np.zeros((n_permute, 3))
#     # a matrix of indices of size n_permute x mn
#     perm_inds = np.zeros((n_permute, m+n))
#     R = range(n_permute)
#     with util.NumpySeedContext(seed=seed):
#         for i in range(n_permute):
#             perm_inds[i] = np.random.permutation(m+n)
#     for i in range(m+n):
#         for j in range(m+n):
#             if i!=j:
#                 # ij is a number array of length 2
#                 IJ = perm_inds[:, [i, j]]
#                 term_inds = which_term(IJ[:, 0], IJ[:, 1], m)
#                 # !!! I hope the sum will not cause an overflow..
#                 terms[R, term_inds] =terms[R, term_inds] + KZ[i, j]

#     n_mmds = terms[:, 0]/float(m*(m-1)) + terms[:, 1]/float(n*(n-1)) - terms[:, 2]*2.0/(m*n)
#     return mmds, terms




In [ ]:

# xte, yte = te.xy()
# perm_mmds2, terms = permutation_list_mmd2_rahul(xte, yte, best_ker, n_permute=30, seed=399)




In [ ]:

# terms




In [ ]:

# perm_mmds = test_result['list_permuted_mmd2']
# plt.hist(perm_mmds, alpha=0.5, label='Current', bins=20)
# plt.hist(perm_mmds2, alpha=0.5, label='New version', bins=20)
# plt.legend(loc='best')




In [ ]:




In [ ]: