This notebook contains code used to generate figures that are not from experimental results. These figures are used in, for instance, the paper and presentation slides.


In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'
import freqopttest.util as util
import freqopttest.data as data
import freqopttest.ex.exglobal as exglo
import freqopttest.kernel as kernel
import freqopttest.tst as tst
import freqopttest.glo as glo
import freqopttest.plot as plot
import matplotlib 
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import sys

Plot simple 2d data


In [ ]:
# Scatter one two-sample draw from the SSGaussVarDiff source (n=200, seed 7).
# Labels are raw strings so LaTeX backslashes are not parsed as (invalid)
# Python escape sequences.
ss = data.SSGaussVarDiff(d=2)
sam = ss.sample(n=200, seed=7)
x, y = sam.xy()

plt.plot(x[:, 0], x[:, 1], 'ob', label=r'$\mathsf{X}$', alpha=0.9, markeredgecolor='b')
plt.plot(y[:, 0], y[:, 1], 'or', label=r'$\mathsf{Y}$', alpha=0.7, markeredgecolor='r')
# Purely illustrative figure: hide both axes and the frame.
plt.gca().get_xaxis().set_visible(False)
plt.gca().get_yaxis().set_visible(False)
plt.box(False)
plt.legend(loc='best')
plt.savefig('gvd_demo.pdf')

Plot blobs dataset


In [ ]:
# Global plotting style for the paper / slide figures.
# Every rc key is set exactly once. The previous version set the font size
# twice, and set usetex twice.
font = {
    'size': 32,
}

matplotlib.rcParams.update({
    'font.size': font['size'],
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'lines.linewidth': 3,
    # Embed TrueType (Type 42) fonts instead of Type 3 in PDF/PS output.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    # Render all text through LaTeX.
    'text.usetex': True,
    'image.cmap': 'jet',
})

In [ ]:
import os
# Folder (relative to this notebook) that receives the oral-presentation figures.
_ORAL_IMG_DIR_PARTS = ('..', 'text', 'nips2016', 'oral_img')


def oral_save_path(rel_path):
    """Return ``rel_path`` placed under ``../text/nips2016/oral_img``.

    The pieces are joined with ``os.path.join``, so the platform separator is used.
    """
    parts = _ORAL_IMG_DIR_PARTS + (rel_path,)
    return os.path.join(*parts)


def prefix_path(rel_path):
    """Alias of ``oral_save_path``, kept so existing call sites keep working."""
    return oral_save_path(rel_path)

In [ ]:
ss_blobs = data.SSBlobs()
sam = ss_blobs.sample(n=2000, seed=2)
bx, by = sam.xy()

# One scatter figure per sample: (points, format, edge colour, title, output file).
# NOTE(review): the second sample `by` is titled "P" and the first `bx` "Q";
# confirm this labelling is intended.
blob_panels = [
    (by, 'ob', 'b', 'Blobs data. Sample from P.', 'blobs_p.pdf'),
    (bx, 'or', 'r', 'Blobs data. Sample from Q.', 'blobs_q.pdf'),
]
for pts, fmt, edge, title, fname in blob_panels:
    plt.figure()
    plt.plot(pts[:, 0], pts[:, 1], fmt, markersize=3, markeredgecolor=edge)
    plt.xlim([-12, 12])
    plt.ylim([-10, 10])
    plt.title(title)
    plt.savefig(fname)

Oral presentation

Figures used in the slides for NIPS 2016 oral presentation.


In [ ]:
# # font options
# font = {
#     #'family' : 'normal',
#     #'weight' : 'bold',
#     'size'   : 18
# }

# plt.rc('font', **font)
# plt.rc('lines', linewidth=3)
# matplotlib.rcParams['pdf.fonttype'] = 42
# matplotlib.rcParams['ps.fonttype'] = 42

In [ ]:
def plot_data_2d(x, y, title, dest_fname):
    """Scatter two 2d samples on the current pyplot axes and save the figure.

    Parameters
    ----------
    x, y : arrays whose columns 0 and 1 are plotted (x in blue, y in red).
    title : str, figure title (may contain LaTeX).
    dest_fname : str, path passed to ``plt.savefig`` (saved with a tight bbox).

    Draws into the current figure; it does not create a new one.
    """
    # Raw strings: in a normal literal, '\m' is an invalid escape sequence.
    plt.plot(x[:, 0], x[:, 1], 'ob', label=r'$\mathsf{X}$', alpha=0.8, markeredgecolor='b')
    plt.plot(y[:, 0], y[:, 1], 'or', label=r'$\mathsf{Y}$', alpha=0.7, markeredgecolor='r')
    ax = plt.gca()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.box(False)
    plt.legend()
    plt.title(title)
    plt.savefig(dest_fname, bbox_inches='tight')

In [ ]:
# Two samples drawn from the same distribution (uniform on [0, 1)^2): a P = Q example.
n = 200
with util.NumpySeedContext(seed=6):
    x, y = np.random.rand(n, 2), np.random.rand(n, 2)
plot_data_2d(x, y, '$P=Q$', oral_save_path('uniform_2d_data.pdf'))

In [ ]:
# X is standard normal; Y has its second coordinate scaled by 4, so P != Q.
n = 400
scale_y = np.diag([1, 4])
with util.NumpySeedContext(seed=9):
    x = np.random.randn(n, 2)
    y = np.random.randn(n, 2).dot(scale_y)
plot_data_2d(x, y, r'$P \neq Q$', oral_save_path('h1_true_data.pdf'))

$H_0/H_1$ distributions


In [ ]:
# Null distribution chi^2(J) of the statistic, with the rejection threshold.
# A noncentral chi^2(J, nc) stands in for the H1 distribution (illustration only).
J = 5
alpha = 0.03

# Domains covering essentially all of the H0 / H1 probability mass.
h0_xmin = stats.chi2.ppf(0.0001, J)
h0_xmax = stats.chi2.ppf(0.9999, J)
domain = np.linspace(h0_xmin, h0_xmax, 300)
# noncentrality parameter of the illustrative H1 distribution
nc = 20
h1_dom = np.linspace(stats.ncx2.ppf(0.0001, J, nc), stats.ncx2.ppf(0.9999, J, nc), 200)

dom = np.sort(np.hstack((domain, h1_dom)))
plt.figure(figsize=(10, 4))
plt.plot(dom, stats.chi2.pdf(dom, J), 'b-',
         label=r'$\mathbb{P}_{H_0}(\hat{\lambda}_n)$'
         # label=r'$\chi^2(J)$'
         )

# Rejection threshold: the (1 - alpha)-quantile of chi^2(J).
thresh = stats.chi2.isf(alpha, J)
plt.plot([thresh, thresh], [0, stats.chi2.pdf(J, J)/2], '*g-', label=r'$T_\alpha$')

# Optional H1 curve (uncomment to produce h0_h1_dists.pdf).
# plt.plot(dom, stats.ncx2.pdf(dom, J, nc), 'r-', label=r'$\mathbb{P}_{H_1}(\hat{\lambda}_n)$')

# Illustrative observed statistic and its p-value region (not drawn by default).
stat = 36
# plt.plot([stat, stat], [0, stats.chi2.pdf(J, J)/3], '*m-', label=r'$\hat{\lambda}_n$', linewidth=2)
pval_x = np.linspace(stat, h0_xmax, 400)
pval_y = stats.chi2.pdf(pval_x, J)
# plt.fill_between(pval_x, np.zeros(len(pval_x)), pval_y, color='gray', alpha=0.5, label='a')

p1 = plt.Rectangle((0, 0), 1, 1, fc='gray')
# shade_leg = plt.legend([p1], ['p-val.'], loc='upper right',
#                        bbox_to_anchor=(0.96, 0.53), frameon=False)
# plt.gca().add_artist(shade_leg)
plt.legend(loc='best', frameon=True)
plt.box(False)
plt.axhline(0, color='black')
plt.gca().get_yaxis().set_visible(False)
plt.gca().xaxis.set_ticks_position('bottom')
plt.xlabel(r'$\hat{\lambda}_n$')

# plt.savefig('h0_dist.pdf', bbox_inches='tight')
plt.savefig('h0_dist_thresh.pdf', bbox_inches='tight')
# plt.savefig('h0_h1_dists.pdf', bbox_inches='tight')

Test power highlight


In [ ]:
# Test power = mass of the H1 density to the right of the threshold T_alpha.
# Reuses dom, J, nc, thresh and h1_dom from the H0/H1 cell above.
plt.figure(figsize=(10, 4))
plt.plot(dom, stats.chi2.pdf(dom, J), 'b-', label=r'$\chi^2(J)$')
plt.plot([thresh, thresh], [0, stats.chi2.pdf(J, J)/2], '*g-', label=r'$T_\alpha$')
plt.plot(dom, stats.ncx2.pdf(dom, J, nc), 'r-', label=r'$\mathbb{P}_{H_1}(\hat{\lambda}_n)$')

# Shade [thresh, max(h1_dom)] under the H1 density.
pow_x = np.linspace(thresh, max(h1_dom), 400)
pow_y = stats.ncx2.pdf(pow_x, J, nc)
plt.fill_between(pow_x, np.zeros(len(pow_x)), pow_y, color='magenta', alpha=0.2)

# A separate "power" legend entry was built here and then immediately replaced
# by the main legend below. To show it, uncomment these lines:
# p2 = plt.Rectangle((0, 0), 1, 1, fc='magenta', alpha=0.2)
# shade_pow = plt.legend([p2], ['power'], loc='upper right',
#                        bbox_to_anchor=(0.97, 0.53), frameon=False)
# plt.gca().add_artist(shade_pow)
plt.box(False)
plt.legend(frameon=True)
plt.gca().get_yaxis().set_visible(False)
plt.axhline(0, color='black')
plt.gca().xaxis.set_ticks_position('bottom')
plt.xlabel(r'$\hat{\lambda}_n$')
plt.savefig('test_power_demo.pdf', bbox_inches='tight')
# plt.savefig('test_power_demo2.pdf', bbox_inches='tight')

Type-1 error


In [ ]:
# Type-I error illustration: the shaded area is the H0 mass beyond T_alpha.
# Reuses J, nc, thresh, domain and h1_dom from the H0/H1 cell above.
h0_pdf = stats.chi2.pdf(domain, J)
h1_pdf = stats.ncx2.pdf(h1_dom, J, nc)
thresh_height = stats.chi2.pdf(J, J)/2

plt.figure(figsize=(10, 4))
plt.plot(domain, h0_pdf, 'b-', label=r'$\mathbb{P}_{H_0}(\hat{\lambda}_n)$')
plt.plot(h1_dom, h1_pdf, 'r-', label=r'$\mathbb{P}_{H_1}(\hat{\lambda}_n)$')
plt.plot([thresh, thresh], [0, thresh_height], '*g-', label=r'$T_\alpha$', linewidth=2)

# Region [thresh, max(domain)] under the H0 density.
t1_x = np.linspace(thresh, max(domain), 400)
t1_y = stats.chi2.pdf(t1_x, J)
plt.fill_between(t1_x, np.zeros(len(t1_x)), t1_y, color='brown', alpha=0.4)

# Extra legend for the shaded region, kept as an artist next to the main legend.
p2 = plt.Rectangle((0, 0), 1, 1, fc='brown', alpha=0.4)
shade_pow = plt.legend([p2], ['type-I'], loc='upper right',
                       bbox_to_anchor=(0.96, 0.53), frameon=False)
plt.box(False)
plt.gca().add_artist(shade_pow)
plt.legend(loc='best', frameon=False)
plt.gca().get_yaxis().set_visible(False)
plt.xlabel(r'$\hat{\lambda}_n$')
plt.savefig('type1_error_demo.pdf')

Noncentral chi-square / power


In [ ]:
# Noncentral chi^2(J, nc) densities for increasing noncentrality nc. The shaded
# area to the right of the fixed threshold is the corresponding test power.
J = 5
dom = np.linspace(1e-5, 95, 500)
thresh = 27
ncs = [10, 30, 50]
pow_colors = ['blue', 'green', 'red']

# The power region does not depend on nc, so build it once.
pow_dom = np.linspace(thresh, max(dom), 500)
plt.figure(figsize=(10, 4))
for nc, color in zip(ncs, pow_colors):
    # Pass the colour explicitly so each curve matches its shaded region.
    # Matplotlib's default colour cycle is not blue/green/red.
    plt.plot(dom, stats.ncx2.pdf(dom, J, nc), label=r'$\chi^2(J, \, %.2g)$' % nc,
             linewidth=2, color=color)
    # The zero baseline must have pow_dom's length (not dom's).
    plt.fill_between(pow_dom, np.zeros_like(pow_dom),
                     stats.ncx2.pdf(pow_dom, J, nc), color=color, alpha=0.2)

plt.plot([thresh, thresh], [0, stats.ncx2.pdf(ncs[0]+J, J, ncs[0])/3],
         '*m-', label=r'$T_\alpha$')
plt.legend(loc='best', frameon=False)
plt.title('$J = %d$' % J)
plt.savefig('ncx2_pow_inc.pdf')

Visualize mean embeddings


In [ ]:
sigma2 = 1


def kgauss_me(Xte, X):
    """Empirical Gaussian-kernel mean embedding of the rows of X, evaluated at the rows of Xte.

    Builds ``kernel.KGauss`` with the module-level ``sigma2`` (read at call
    time). It averages the kernel matrix ``eval(Xte, X)`` over axis 1. This
    presumably gives one value per row of Xte; that assumes ``eval`` returns an
    (n_te, n_x) matrix, TODO confirm. Inputs are 2d arrays.
    """
    gram = kernel.KGauss(sigma2=sigma2).eval(Xte, X)
    return np.mean(gram, axis=1)
    
# 1d illustration: empirical mean embeddings of X (blue) and Y (red), and their
# rescaled squared difference (green).
X = np.array([[2, 3.2, 4]]).T
Y = np.array([[4.5, 4.9, 6]]).T
xy = np.vstack((X, Y))

# Use scalar bounds. Builtin min()/max() over a 2d array yield 1-element arrays,
# which make np.linspace (numpy >= 1.16) return shape (300, 1) and turn
# dom_mat into a 3d array.
dom = np.linspace(xy.min() - 3*sigma2, xy.max() + 3*sigma2, 300)
dom_mat = dom[:, np.newaxis]

me_x = kgauss_me(dom_mat, X)
me_y = kgauss_me(dom_mat, Y)
me_diff = me_x - me_y

markersize = 9
plt.figure(figsize=(10, 5))
# data X, drawn slightly above zero
plt.plot(X[:, 0], np.zeros(X.shape[0]) + 5e-3, 'ob', markersize=markersize)
# mu_x
plt.plot(dom, me_x, 'b-', label=r'$\hat{\mu}_P(\mathbf{v})$')
# data Y
plt.plot(Y[:, 0], np.zeros(Y.shape[0]) + 5e-3, 'or', markersize=markersize)
# mu_y
plt.plot(dom, me_y, 'r-', label=r'$\hat{\mu}_Q(\mathbf{v})$')
# optional unsquared difference:
# plt.plot(dom, me_diff, '-g', label=r'$\hat{\mu}_P(\mathbf{v}) - \hat{\mu}_Q(\mathbf{v})$')

# Squared difference, rescaled so its peak is 1.2x the tallest embedding curve.
maxheight = np.max(np.maximum(me_x, me_y))
max_diff2 = np.max(me_diff**2)
plt.plot(dom, me_diff**2/max_diff2*maxheight*1.2, '-g',
         label=r'$(\hat{\mu}_P(\mathbf{v}) - \hat{\mu}_Q(\mathbf{v}))^2$')

# Dashed stem from each data point up to its own sample's embedding value.
self_me_x = kgauss_me(X, X)
for i in range(X.shape[0]):
    plt.plot([X[i, 0], X[i, 0]], [0, self_me_x[i]], '--', color='blue', alpha=0.4)
self_me_y = kgauss_me(Y, Y)
for i in range(Y.shape[0]):
    plt.plot([Y[i, 0], Y[i, 0]], [0, self_me_y[i]], '--', color='red', alpha=0.4)

plt.xlabel(r'$\mathbf{v}$')
plt.legend(loc='best')
plt.gca().get_xaxis().set_visible(False)
plt.gca().get_yaxis().set_visible(False)
plt.box(False)
plt.savefig('mean_embeddings_diff.pdf')

Densities in the frequency domain


In [ ]:
import scipy as sp
import scipy.signal as sig

def tri_fun(x, w=1.0):
    """Triangular ("hat") function with peak 1 at 0 and support (-w, w).

    Computes ``max(0, 1 - |x / w|)``. This equals the degree-1 B-spline
    ``scipy.signal.bspline(x / w, 1)`` used previously, which SciPy removed
    in version 1.13.

    Parameters
    ----------
    x : scalar or array of evaluation points.
    w : half-width of the support.

    Returns
    -------
    ndarray (0-d for scalar input) of the same shape as ``x``.
    """
    return np.maximum(0.0, 1.0 - np.abs(np.asarray(x, dtype=float) / w))

def smooth_ker(x, w=1.0):
    """Gaussian smoothing kernel: the N(0, w^2) probability density evaluated at ``x``."""
    gauss = stats.norm(loc=0, scale=w)
    return gauss.pdf(x)
    
# Two triangular "characteristic functions" of different widths. The cell also
# computes their discrete convolutions with a Gaussian kernel (unnormalized: not
# multiplied by the grid spacing), which the next cell plots.
v1 = 3
dom = np.linspace(-4, 4, 1000)
x = tri_fun(dom, w=1)
y = tri_fun(dom, w=2)

k = smooth_ker(dom, w=1.3)
x_k = np.convolve(x, k, 'same')
y_k = np.convolve(y, k, 'same')

plt_scale = 1.0/20
plt.figure(figsize=(7, 4))
# test frequency v_1
plt.plot([v1, v1], [0, np.max(x)/3.0], '--og', label=r'$\mathbf{v}_1$')
plt.plot(dom, x, 'b-', label=r'$\hat{p}(\mathbf{\omega})$')
plt.plot(dom, y, 'r-', label=r'$\hat{q}(\mathbf{\omega})$')
plt.legend(loc='best')
# Raw string: '\h', '\m', '\o' are invalid escape sequences in a normal literal.
plt.title(r'Characteristic functions $\hat{p}(\mathbf{\omega}), \hat{q}(\mathbf{\omega})$')
plt.gca().get_yaxis().set_visible(False)
plt.box(False)
plt.savefig('characteristic_funcs.pdf')
# plt.plot(dom, plt_scale*x_k, '-r')

In [ ]:
# Smoothed characteristic functions: x_k and y_k from the previous cell.
plt.figure(figsize=(7, 4))
# test frequency v_1
plt.plot([v1, v1], [0, np.max(y_k)/3.0], '--og', label=r'$\mathbf{v}_1$')
plt.plot(dom, x_k, 'b-', label=r'$(l \ast \hat{p})(\mathbf{\omega})$')
# y_k is the smoothed q-curve, so its label uses \hat{q}.
plt.plot(dom, y_k, 'r-', label=r'$(l \ast \hat{q})(\mathbf{\omega})$')
plt.legend()
plt.title('Smoothed characteristic functions')
plt.gca().get_yaxis().set_visible(False)
plt.box(False)
plt.savefig('smooth_cfs.pdf')

ME test: interactive test locations


In [ ]:
from mpl_toolkits.mplot3d import axes3d
from matplotlib import cm

def best_loc2_testpower(X, Y, gamma, loc1, dest_fname='lambda_t2_surface.pdf'):
    """Contour plot of the ME test statistic over a grid of test locations.

    With J = 2, the first location ``loc1`` is held fixed and the second, v_2,
    varies over a 60 x 60 grid spanning the bounding box of the pooled data.
    If ``loc1`` is None, a single test location (J = 1) varies instead. The
    samples X (blue) and Y (red) are overlaid, plus ``loc1`` as a black triangle.

    Parameters
    ----------
    X, Y : 2d arrays with two columns (the two samples).
    gamma : kernel width forwarded unchanged to ``t2_stat``.
    loc1 : length-2 array for the fixed location v_1, or None.
    dest_fname : path for ``plt.savefig``; None skips saving. The default is the
        file name this function always wrote to before.

    Draws on the current pyplot figure. Requires ``t2_stat`` to be defined.
    """
    XY = np.vstack((X, Y))
    max1, max2 = np.max(XY, 0)
    min1, min2 = np.min(XY, 0)
    # grid resolution along each coordinate
    nd1 = 60
    nd2 = 60
    loc1_cands = np.linspace(min1, max1, nd1)
    loc2_cands = np.linspace(min2, max2, nd2)
    lloc1, lloc2 = np.meshgrid(loc1_cands, loc2_cands)
    # (nd2 * nd1) x 2 candidate locations, in meshgrid row-major order
    all_loc2s = np.reshape(np.dstack((lloc1, lloc2)), (-1, 2))

    # all_locs: #candidates x J x 2, with the varying candidate first
    if loc1 is not None:
        all_locs = np.array([np.vstack((c, loc1)) for c in all_loc2s])
    else:
        all_locs = all_loc2s[:, np.newaxis, :]

    # Statistic at every candidate, reshaped back onto the grid.
    stat_grid = np.array([t2_stat(X, Y, T, gamma) for T in all_locs])
    stat_grid = np.reshape(stat_grid, (nd2, nd1))

    plt.contourf(lloc1, lloc2, stat_grid, alpha=0.28)
    plt.colorbar()
    if loc1 is not None:
        plt.title(r'$\mathbf{v}_2 \mapsto \hat{\lambda}_{n}^{tr}(\mathbf{v}_1, \mathbf{v}_2)$')
    else:
        plt.title(r'$\mathbf{v} \mapsto \hat{\lambda}_{n}^{tr}(\mathbf{v})$')

    # Overlay the data (raw strings: '\m' is an invalid escape otherwise).
    plt.plot(X[:, 0], X[:, 1], 'ob', label=r'$\mathsf{X}$',
             markeredgecolor='b', markersize=4, alpha=0.9)
    plt.plot(Y[:, 0], Y[:, 1], 'or', label=r'$\mathsf{Y}$',
             markeredgecolor='r', markersize=4, alpha=0.9)
    if loc1 is not None:
        loc1x, loc1y = loc1
        plt.plot(loc1x, loc1y, '^k', markersize=20, label=r'$\mathbf{v}_1$')
    if dest_fname is not None:
        plt.savefig(dest_fname)

In [ ]:
def t2_stat(X, Y, locs, gamma):
    """Return the ME test statistic of the samples X, Y at the given test locations.

    locs: J x d array of test locations.
    gamma: kernel width passed to ``tst.MeanEmbeddingTest``.
    The notebook-level ``alpha`` (read at call time) is used as the test's
    significance level. Only the ``'test_stat'`` entry of the result is returned.
    """
    pair = data.TSTData(X, Y)
    result = tst.MeanEmbeddingTest(locs, gamma, alpha=alpha).perform_test(pair)
    return result['test_stat']


# Significance level read (as a global) by t2_stat.
alpha = 0.01
# Two-sample problem for the interactive test-location demo.
# Alternatives tried: data.SSGaussVarDiff(d=2), data.SSBlobs().
ss = data.SSGaussMeanDiff(d=2, my=1.0)
n = 500
tst_data = ss.sample(n=n, seed=6)
X, Y = tst_data.xy()

In [ ]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed
from IPython.display import display
import ipywidgets as widgets

# interactively select test locations
def me_test_plot_interact(X, Y, loc1x=0, loc1y=0, gamma=1):
    """Widget callback: draw the test-criterion contour with v_1 = (loc1x, loc1y).

    A thin wrapper so that scalar ipywidgets sliders can drive ``best_loc2_testpower``.
    """
    best_loc2_testpower(X, Y, gamma, np.array([loc1x, loc1y]))


# Slider ranges (min, max, step) for the two coordinates of v_1.
loc1_bnd = (-5, 5, 0.1)
loc2_bnd = loc1_bnd
vs = interactive(
    me_test_plot_interact,
    X=fixed(X),
    Y=fixed(Y),
    loc1x=loc1_bnd,
    loc1y=loc2_bnd,
    gamma=(0.1, 10, 0.1),
)
display(vs)

In [ ]:
# Another dataset: for each of the n players (n from an earlier cell), a match
# count, an observed win probability (pobs) and a noisy "model" probability
# (pmodel) clipped to [0, 1]. All randomness is drawn under one fixed seed.
with util.NumpySeedContext(seed=74):
    # match count of each player (at least 1)
    counts = 1 + stats.expon.rvs(scale=15, size=n).astype(int)

    A = stats.uniform.rvs(loc=10*counts, scale=4, size=n)
    B = stats.uniform.rvs(loc=7*counts**0.6, scale=4, size=n)
    pobs = stats.beta.rvs(a=A, b=B)
    rand_signs = stats.bernoulli.rvs(0.5, size=n)*2 - 1
    # noise shrinks with the match count; randn is drawn before rand
    noise = 0.5*np.random.randn(n)/counts**1.5
    jitter = rand_signs*np.random.rand(n)*0.03
    pmodel = np.maximum(np.minimum(pobs + noise + jitter, 1), 0)

X = np.vstack((counts, pobs)).T
Y = np.vstack((counts, pmodel)).T

plt.plot(pobs, pmodel, 'ko')
plt.xlabel('P obs')
plt.ylabel('P model')

In [ ]:
# util.meddistance gives a median *distance*. The variable name, the printed
# "Gaussian width^2" and the `meddistance(...)**2` heuristic used elsewhere in
# this notebook all indicate that a squared width is intended here.
gw2 = util.meddistance(np.vstack((X, Y)), subsample=1000)**2
print('Gaussian width^2: {0}'.format(gw2))
plt.figure(figsize=(10, 5))
best_loc2_testpower(X, Y, gw2, loc1=None)
plt.title('Test criterion')
plt.xlabel('Match count')
plt.ylabel('P(win)')
plt.xlim([1, 40])
plt.savefig('tomminka_game_problem.pdf', bbox_inches='tight')

ME test vs MMD witness function


In [ ]:
# 1d two-sample problem built with data.SSMixUnif1D from the intervals below.
# NOTE(review): presumably P is uniform on [-2, 0] and Q is a mixture of
# uniforms on the rows of qbs; confirm against data.SSMixUnif1D.
# Other sources tried: SSGaussMeanDiff, SSUnif, SSGaussVarDiff, SSBlobs.
m = 3000
n = m
slack = 0.0
pbs = np.array([[-2, 0]])
qbs = np.array([[-2 + slack, 0 - slack], [2, 4]])
ss = data.SSMixUnif1D(pbs, qbs)

tst_data = ss.sample(m, seed=9)
tr, te = tst_data.split_tr_te(tr_proportion=0.5, seed=11)

In [ ]:
# Scan a single test location v over a 1d grid. On the training split, record
# the ME statistic, its variance, the unnormalized statistic and the witness
# value W = mean_i [k(x_i, v) - k(y_i, v)].
alpha = 0.01
# gwidth = util.meddistance(tr.stack_xy(), subsample=1000)**2
gwidth = 0.3  # Gaussian width squared
reg = 0.0

dom = np.linspace(-6, 6, 400)
xtr, ytr = tr.xy()
# W and Sig are computed on the training split only, so the statistic is scaled
# by the training size, not by the full sample size n.
n_tr = xtr.shape[0]

test_stats = np.zeros(len(dom))
# variance of the statistic and the unnormalized statistic
sigs = np.zeros(len(dom))
un_stats = np.zeros(len(dom))
witness = np.zeros(len(dom))
varx = np.zeros(len(dom))
vary = np.zeros(len(dom))
for i, t1x in enumerate(dom):
    T = np.array([[t1x]])
    g = tst.MeanEmbeddingTest.gauss_kernel(xtr, T, gwidth)
    h = tst.MeanEmbeddingTest.gauss_kernel(ytr, T, gwidth)
    varx[i] = np.cov(g.T)
    vary[i] = np.cov(h.T)

    # Variance taken as var(g) + var(h), i.e. the cross-covariance term is
    # dropped (X and Y are independent).
    Z = g - h
    Sig = varx[i] + vary[i]
    W = np.mean(Z, 0)

    test_stats[i] = n_tr*(W[0]**2)/(Sig + reg)
    sigs[i] = Sig
    un_stats[i] = n_tr*W[0]**2
    witness[i] = W[0]

print('gwidth**2: %.3f' % gwidth)

In [ ]:
# Normalized ME statistic versus its two ingredients, on a common (hidden) y scale.
plt.figure(figsize=(10, 5))
plt.plot(dom, test_stats, 'g-', label=r'$\hat{\lambda}_n(\mathbf{v})$')
# Variance, rescaled so its peak is half the statistic's peak.
norm_sigs = sigs/np.max(sigs)*np.max(test_stats)/2
plt.plot(dom, norm_sigs, 'm-', label=r'$\propto \mathbf{S}_n(\mathbf{v})$')
# Unnormalized statistic, rescaled to the statistic's peak.
norm_un_stats = un_stats/np.max(un_stats)*np.max(test_stats)
plt.plot(dom, norm_un_stats, 'k-',
         label=r'$(\hat{\mu}_P(\mathbf{v}) - \hat{\mu}_Q(\mathbf{v}))^2$')

# The legend is anchored outside the axes (x = 1.2 in axes coordinates), so the
# figure is saved with a tight bounding box to keep it from being clipped.
plt.legend(bbox_to_anchor=(1.2, 1))
plt.xlabel(r'$\mathbf{v}$', fontsize=40)
plt.gca().get_yaxis().set_visible(False)
plt.gca().xaxis.set_ticks_position('bottom')
plt.box(False)
plt.savefig('witness_vs_normalized_stat.pdf', bbox_inches='tight')

In [ ]:
# Per-sample variances and their sum (sigs = varx + vary) over the scanned v.
plt.figure(figsize=(10, 4))
plt.plot(dom, varx, 'b--', label=r'$\hat{s}_\mathbf{x}(\mathbf{v})$')
plt.plot(dom, vary, 'r--', label=r'$\hat{s}_\mathbf{y}(\mathbf{v})$')
plt.plot(dom, sigs, 'm-', label=r'$\hat{s}(\mathbf{v})$', alpha=0.5)
# Raw string: '\m' is an invalid escape sequence in a normal literal.
plt.xlabel(r'$\mathbf{v}$', fontsize=40)
plt.ylim([0, np.max(sigs)+5e-3])
plt.gca().get_yaxis().set_visible(False)
plt.gca().xaxis.set_ticks_position('bottom')
plt.box(False)
plt.legend(bbox_to_anchor=(0.3, 1))

# plt.savefig('me_var_x.pdf', bbox_inches='tight')
# plt.savefig('me_var_xy.pdf', bbox_inches='tight')
plt.savefig('me_var.pdf', bbox_inches='tight')

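In the above figure, $\sigma_n^2 = \mathbf{S}_n$ because there is a single test location ($J = 1$), so the $J \times J$ covariance matrix $\mathbf{S}_n$ is a scalar. Because $X$ and $Y$ are independent, the cross-covariance between $k(x_i, v)$ and $k(y_i, v)$ vanishes in expectation. Hence $\sigma_n^2 \approx \sigma^2_x + \sigma^2_y$, and the figure plots this sum as $\hat{s}(\mathbf{v})$. Here $\sigma^2_x(v) = \frac{1}{n}\sum_{i=1}^n \left( k(x_i, v) - \frac{1}{n} \sum_{j=1}^n k(x_j, v)\right)^2$ and $\sigma^2_y(v) = \frac{1}{n}\sum_{i=1}^n \left( k(y_i, v) - \frac{1}{n} \sum_{j=1}^n k(y_j, v)\right)^2$. The code computes these with `np.cov`, which normalizes by $n-1$ rather than $n$; the difference is negligible at these sample sizes.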


In [ ]:
# Raw witness values W(v) from the location scan, plotted over dom.
witness_fig = plt.figure(figsize=(10, 4))
witness_ax = witness_fig.add_subplot(1, 1, 1)
witness_ax.plot(dom, witness, 'k-')
witness_ax.set_title('MMD witness function')
witness_ax.grid(True)

In [ ]:
def plot_witness(px_label=r'$P$', py_label=r'$Q$'):
    """Plot the densities of P and Q together with the witness function.

    Reads three module-level names that must already exist: ``ss`` (for
    ``density_p`` / ``density_q``), ``dom`` (the evaluation grid) and
    ``witness`` (witness values on ``dom``). Creates a new 8x4 figure and
    does not save it.

    Parameters
    ----------
    px_label, py_label : legend labels for the density curves of P and Q.
    """
    plt.figure(figsize=(8, 4))
    # Floor P's density at zthresh (Q's is left unchanged). This builds a new
    # array instead of mutating the one returned by density_p.
    # NOTE(review): presumably this keeps P's curve visible where it is zero; confirm.
    zthresh = 1e-2
    pden = np.maximum(ss.density_p(dom), zthresh)
    qden = ss.density_q(dom)

    plt.plot(dom, pden, 'b-', alpha=1, label=px_label)
    plt.plot(dom, qden, 'r-', alpha=0.99, label=py_label)
    plt.plot(dom, witness, 'k-', label=r'$\mathrm{witness}$')
    plt.legend(fontsize=18, bbox_to_anchor=(0.32, 1.00))
    plt.xlim([np.min(dom), np.max(dom)])
    plt.ylim([-0.4, 0.55])
    plt.yticks([-0.25, 0, 0.25, 0.5], fontsize=18)
    plt.xticks(fontsize=18)
    plt.gca().xaxis.set_ticks_position('bottom')
    plt.grid()
    
    
######
plot_witness()
witness_pdf = prefix_path('unif_overlap_unsquared_witness.pdf')
plt.savefig(witness_pdf, bbox_inches='tight')

Variances for the two-uniform problem


In [ ]:
# Location scan on a 200-point grid with gwidth = 1.0. Unlike the earlier scan,
# Sig here is the sample variance of Z = g - h (cross-covariance included).
# NOTE: this overwrites test_stats, sigs, un_stats, witness, varx and vary with
# arrays of length len(t1x_list). Re-run the earlier scan before re-plotting it.
alpha = 0.01
# gwidth = util.meddistance(tr.stack_xy(), subsample=1000)**2
gwidth = 1.0  # Gaussian width squared
reg = 0.0

t1x_list = np.linspace(-6, 6, 200)
# W and Sig come from the training split only, so scale by its size.
n_tr = xtr.shape[0]

test_stats = np.zeros(len(t1x_list))
# variance of the statistic and the unnormalized statistic
sigs = np.zeros(len(t1x_list))
un_stats = np.zeros(len(t1x_list))
witness = np.zeros(len(t1x_list))
varx = np.zeros(len(t1x_list))
vary = np.zeros(len(t1x_list))
for i, t1x in enumerate(t1x_list):
    T = np.array([[t1x]])
    g = tst.MeanEmbeddingTest.gauss_kernel(xtr, T, gwidth)
    h = tst.MeanEmbeddingTest.gauss_kernel(ytr, T, gwidth)
    Z = g - h
    Sig = np.cov(Z.T)
    W = np.mean(Z, 0)

    test_stats[i] = n_tr*(W[0]**2)/(Sig + reg)
    sigs[i] = Sig
    un_stats[i] = n_tr*W[0]**2
    witness[i] = W[0]
    varx[i] = np.cov(g.T)
    vary[i] = np.cov(h.T)

print('gwidth**2: %.3f' % gwidth)

In [ ]:


In [ ]:


In [ ]:


In [ ]:


In [ ]: