A notebook to test and demonstrate the MMD test
of Gretton et al., 2012, used as a goodness-of-fit test. Requires the ability to sample from the density p.
In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'
import freqopttest.tst as tst
import kgof
import kgof.data as data
import kgof.density as density
import kgof.goftest as gof
import kgof.mmd as mgof
import kgof.kernel as ker
import kgof.util as util
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
In [ ]:
# Global matplotlib styling applied to every figure in this notebook.
# Further font keys that can be enabled: 'family': 'normal', 'weight': 'bold'.
font = dict(size=16)
plt.rc('font', **font)
plt.rc('lines', linewidth=2)
# Font type 42 for the PDF and PS backends (Type 42 / TrueType embedding per
# the matplotlibrc documentation).
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})
In [ ]:
# Settings for the Gaussian toy problem.
# seed: base random seed (later cells add offsets to it)
# d:    dimension of the mean vector
# n:    number of points to sample
seed, d, n = 20, 2, 400
# Passed as `alpha` to the MMD tests below.
alpha = 0.05
# Parameters of the true p: zero mean vector of length d, unit variance.
variance = 1
mean = np.zeros(d)
In [ ]:
# Model under test, built from the mean vector and variance defined above.
p = density.IsotropicNormal(mean, variance)

# Data source q: same variance as p, but every coordinate of the mean is
# shifted by +1, so the data do not come from p.
q_variance = variance
q_mean = np.array(mean)  # independent copy; `mean` itself is left untouched
# q_mean[0] = 1
ds = data.DSIsotropicNormal(q_mean + 1, q_variance)

# Alternative data source (isotropic Gaussian mixture), kept for reference:
# q_means = np.array([ [0], [0]])
# q_variances = np.array([0.01, 1])
# ds = data.DSIsoGaussianMixture(q_means, q_variances, pmix=[0.2, 0.8])
In [ ]:
# Test
# Draw n observations from the data source; this draw uses its own seed
# offset (seed+2).
dat = ds.sample(n, seed=seed+2)
# Raw values held by `dat` (assumed to be an n x d array -- TODO confirm).
X = dat.data()
# Use median heuristic to determine the Gaussian kernel width
# (the value returned by meddistance is squared before being given to KGauss).
sig2 = util.meddistance(X, subsample=1000)**2
# NOTE: `k` stays in the notebook namespace and is reused by the
# kernel-matrix cell further down.
k = ker.KGauss(sig2)
In [ ]:
# MMD goodness-of-fit test of the sample `dat` against the model p, using
# kernel k. n_permute is presumably the number of permutations used to
# simulate the null distribution -- confirm against kgof.mmd.
mmd_test = mgof.QuadMMDGof(
    p,
    k,
    n_permute=300,
    alpha=alpha,
    seed=seed,
)
mmd_result = mmd_test.perform_test(dat)
# Bare expression so the notebook displays the result.
mmd_result
In [ ]:
# Report the test decision stored in the result.
print('Reject H0?: {0}'.format(mmd_result['h0_rejected']))

# Simulated (permuted) statistics and the observed statistic.
sim_stats = mmd_result['list_permuted_mmd2']
stat = mmd_result['test_stat']

# One equal weight per simulated value (each is 1/len(sim_stats)), so the
# histogram bars show relative frequencies rather than raw counts.
unif_weights = np.ones_like(sim_stats) / float(len(sim_stats))
plt.hist(sim_stats, weights=unif_weights, label='Simulated')

# Mark the observed statistic with a red star on the x-axis.
plt.plot([stat, stat], [0, 0], 'r*', markersize=30, label='Stat')
plt.legend(loc='best')
In [ ]:
def gbrbm_perturb(std_perturb_B, dx=50, dh=10):
    """
    Get a Gaussian-Bernoulli RBM problem where the first entry of the B matrix
    (the matrix linking the latent and the observation) is perturbed.

    - std_perturb_B: standard deviation of the Gaussian noise added to
        B[0, 0]. Values <= 1e-7 leave B unperturbed.
    - dx: observed dimension
    - dh: latent dimension

    Return p (density), data source
    """
    # NOTE(review): the original indentation of this notebook export was
    # lost. The perturbation draw is assumed to belong inside the seed
    # context (so that it is reproducible) -- confirm against the notebook.
    with util.NumpySeedContext(seed=10):
        # B has entries in {-1.0, +1.0}; b and c are standard normal draws.
        B = np.random.randint(0, 2, (dx, dh))*2 - 1.0
        b = np.random.randn(dx)
        c = np.random.randn(dh)
        # The model p uses the unperturbed B.
        p = density.GaussBernRBM(B, b, c)

        # The data source uses a copy of B whose first entry may be shifted.
        B_perturb = B.copy()
        if std_perturb_B > 1e-7:
            # Draw a scalar: assigning a shape-(1,) array into a single
            # element is deprecated since NumPy 1.25. randn() consumes the
            # same random value as randn(1).
            B_perturb[0, 0] += np.random.randn()*std_perturb_B
        ds = data.DSGaussBernRBM(B_perturb, b, c, burnin=2000)
    return p, ds
def gbrbm_perturb_all(std_perturb_B, dx=50, dh=10):
    """
    Get a Gaussian-Bernoulli RBM problem where all entries of B
    (the matrix linking the latent and the observation) are perturbed.

    - std_perturb_B: standard deviation of the independent Gaussian noise
        added to every entry of B. Values <= 1e-7 leave B unperturbed.
    - dx: observed dimension
    - dh: latent dimension

    Return p (density), data source
    """
    # NOTE(review): the original indentation of this notebook export was
    # lost. The perturbation draw is assumed to belong inside the seed
    # context (so that it is reproducible) -- confirm against the notebook.
    with util.NumpySeedContext(seed=11):
        # B has entries in {-1.0, +1.0}; b and c are standard normal draws.
        B = np.random.randint(0, 2, (dx, dh))*2 - 1.0
        b = np.random.randn(dx)
        c = np.random.randn(dh)
        # The model p uses the unperturbed B.
        p = density.GaussBernRBM(B, b, c)

        # Start from an unperturbed copy so that B_perturb is always bound.
        # Previously a negligible std_perturb_B skipped the only assignment
        # and the next line failed with an unbound-name error.
        B_perturb = B.copy()
        if std_perturb_B > 1e-7:
            B_perturb += np.random.randn(dx, dh)*std_perturb_B
        ds = data.DSGaussBernRBM(B_perturb, b, c, burnin=2000)
    return p, ds
In [ ]:
# Configuration of the RBM experiment. These rebind n, d and seed from the
# Gaussian example above.
n, d, seed = 1000, 50, 991

# Model p and a data source whose B matrix has its first entry perturbed
# (noise standard deviation sqrt(0.1)).
# p, qds = gbrbm_perturb_all(0.06, dx=d, dh=10)
p, qds = gbrbm_perturb(np.sqrt(0.1), dx=d, dh=10)

# Y: n points drawn from the perturbed data source.
qdat = qds.sample(n, seed=seed + 3)
Y = qdat.data()

# X: n points drawn from the data source associated with the model p itself.
pds = p.get_datasource()
datX = pds.sample(n, seed=seed + 1)
X = datX.data()

# Both samples stacked row-wise (X first, then Y).
XY = np.vstack([X, Y])
In [ ]:
# Variance along axis 0 of X, the sample drawn from the model's data source.
np.var(X, axis=0)
In [ ]:
# Variance along axis 0 of Y, the sample drawn from the perturbed data source.
np.var(Y, axis=0)
In [ ]:
# Multiplicative factors applied to the squared per-dimension medians:
# 30 values, log-spaced in base 2 from 2**-5 to 2**5.
med_factors = 2.0**np.linspace(-5, 5, 30)

# Median heuristic computed separately for each of the d coordinates of the
# pooled sample XY (column j is kept two-dimensional via XY[:, [j]]).
meds = np.fromiter(
    (util.meddistance(XY[:, [j]], subsample=1000) for j in range(d)),
    dtype=float,
    count=d,
)

# One diagonal Gaussian kernel per factor.
candidate_kernels = [
    ker.KDiagGauss((meds**2)*factor) for factor in med_factors
]
# k = ker.KDiagGauss(2*meds**2)

# Alternative (disabled): isotropic Gaussian kernels based on multiples of
# the median heuristic of the full data.
# med = util.meddistance(XY, subsample=1000)
# candidate_kernels = [ker.KGauss(f*med**2) for f in med_factors]
# k = ker.KGauss((2.0**-1)*med**2)
# candidate_kernels = [k]
In [ ]:
# MMD goodness-of-fit test given a list of candidate kernels rather than a
# single fixed one. tr_proportion and reg are forwarded unchanged; presumably
# a fraction of the data is used to choose among the candidates -- confirm
# against kgof.mmd.QuadMMDGofOpt.
mmd_opt = mgof.QuadMMDGofOpt(
    p,
    n_permute=300,
    alpha=alpha,
    seed=seed + 3,
)
mmd_result = mmd_opt.perform_test(
    qdat,
    candidate_kernels=candidate_kernels,
    tr_proportion=0.2,
    reg=1e-3,
)
# Bare expression so the notebook displays the result.
mmd_result
In [ ]:
# Kernel matrices between and within the two samples X and Y.
# NOTE(review): `k` is not assigned anywhere in the RBM cells -- the
# `k = ker.KDiagGauss(...)` and `k = ker.KGauss(...)` lines for this data are
# commented out. The `k` in scope is therefore the KGauss built earlier from
# the 2-D Gaussian sample. Confirm that this is the kernel intended here.
Kxy = k.eval(X, Y)
Kxx = k.eval(X, X)
Kyy = k.eval(Y, Y)
# Show the cross-sample kernel matrix as an image with a colour scale.
plt.figure(figsize=(8, 8))
plt.imshow(Kxy)
plt.title('Kxy')
plt.colorbar()
In [ ]:
# Histogram (50 bins) of all entries of the cross-sample kernel matrix.
plt.hist(Kxy.ravel(), bins=50)
In [ ]:
# Show the kernel matrix of X against itself as an image with a colour scale.
plt.figure(figsize=(8, 8))
plt.imshow(Kxx)
plt.title('Kxx')
plt.colorbar()
In [ ]:
# Show the kernel matrix of Y against itself as an image with a colour scale.
plt.figure(figsize=(8, 8))
plt.imshow(Kyy)
plt.title('Kyy')
plt.colorbar()
In [ ]:
# Mean over all (i, j) entries of Kxx + Kyy - 2*Kxy, diagonal entries
# included. X and Y were both drawn with n rows, so the three matrices are
# combined elementwise.
# NOTE(review): this has the form of the biased (V-statistic) estimate of the
# *squared* MMD, so `mmd` presumably holds MMD^2 rather than MMD -- confirm.
mmd = np.mean(Kxx+Kyy-2*Kxy)
# Bare expression so the notebook displays the value.
mmd
In [ ]:
In [ ]:
In [ ]: