In [1]:
%load_ext autoreload
%autoreload 2
import sympy as smp
import numpy as np
import scipy.special as sp

import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
# Global seaborn look for every figure in this notebook.
sns.set_style("white")

# Named xkcd colours; colors[0] is used for the G-REP curve and colors[1]
# for the RSVI curve in the comparison plot.
color_names = [
    "windows blue",
    "red",
    "gold",
    "grass green",
]
colors = sns.xkcd_palette(color_names)

In [2]:
def logQ_grep(eps, alpha):
    """Log-density of the standardized G-REP noise for z ~ Gamma(alpha, 1).

    G-REP standardizes ``log z`` as ``eps = (log z - psi(alpha)) / sqrt(psi_1(alpha))``
    (psi = digamma, psi_1 = trigamma). By change of variables the density of
    ``eps`` is::

        Q(eps) = z**alpha * exp(-z) * sqrt(psi_1(alpha)) / Gamma(alpha),
        z      = exp(eps * sqrt(psi_1(alpha)) + psi(alpha)).

    Parameters
    ----------
    eps : float or ndarray
        Point(s) at which to evaluate the log-density.
    alpha : float
        Gamma shape parameter (assumed > 0).

    Returns
    -------
    float or ndarray
        ``log Q(eps)``, broadcast to the shape of ``eps``.
    """
    trigamma = sp.polygamma(1, alpha)
    root_trigamma = np.sqrt(trigamma)
    psi = sp.digamma(alpha)

    # Terms of log Q that do not depend on eps.
    log_normalizer = 0.5 * np.log(trigamma) + alpha * psi - sp.gammaln(alpha)
    # The Gamma sample z recovered from eps.
    z = np.exp(eps * root_trigamma + psi)
    return eps * alpha * root_trigamma - z + log_normalizer

def derH_eps(eps, alpha):
    """Derivative dh/deps of the Marsaglia-Tsang map used by ``fun_H``.

    With ``b = alpha - 1/3`` and ``c = 1 / sqrt(9 b)`` the map is
    ``h(eps) = b * (1 + c * eps)**3``, hence ``dh/deps = 3 b c (1 + c * eps)**2``.

    Parameters
    ----------
    eps : float or ndarray
        Noise value(s).
    alpha : float
        Gamma shape parameter; ``alpha - 1/3`` must be positive for a real result.

    Returns
    -------
    float or ndarray
        ``dh/deps`` evaluated at ``eps``.
    """
    shape_shift = alpha - 1. / 3.
    slope = 1. / np.sqrt(9. * shape_shift)
    prefactor = 3. * shape_shift * slope
    return prefactor * (1. + eps * slope) ** 2

def fun_H(eps, alpha):
    """Marsaglia-Tsang map h(eps, alpha) = (alpha - 1/3) * (1 + eps / sqrt(9 (alpha - 1/3)))**3.

    Parameters
    ----------
    eps : float or ndarray
        Noise value(s).
    alpha : float
        Gamma shape parameter; ``alpha - 1/3`` must be positive for a real result.

    Returns
    -------
    float or ndarray
        ``h(eps, alpha)``; positive only where ``eps > -sqrt(9 (alpha - 1/3))``.
    """
    shape_shift = alpha - 1. / 3.
    slope = 1. / np.sqrt(9. * shape_shift)
    return shape_shift * (1. + eps * slope) ** 3

def Q_rej(eps, alpha):
    """Density of the rejection-sampler noise eps for a Gamma(alpha, 1) target.

    With the Marsaglia-Tsang map ``h = fun_H(eps, alpha)`` this evaluates::

        Q(eps) = h**(alpha - 1) * exp(-h) * |dh/deps| / Gamma(alpha),

    which is the Gamma(alpha, 1) density at ``h`` times the Jacobian
    ``|derH_eps|``. ``h`` is positive only for ``eps > -sqrt(9 (alpha - 1/3))``.
    Below that cutoff ``Q`` is set to 0.

    Parameters
    ----------
    eps : float or array_like
        Noise value(s). Integer and list inputs are accepted and converted to float.
    alpha : float
        Gamma shape parameter. Must be > 1/3, otherwise the map is undefined.

    Returns
    -------
    ndarray of float
        ``Q(eps)`` with the same shape as ``eps``.

    Raises
    ------
    ValueError
        If ``alpha <= 1/3`` (or is NaN).
    """
    shape_shift = alpha - 1. / 3.
    # `not (x > 0)` also rejects NaN. Previously a NaN cutoff silently produced all zeros.
    if not shape_shift > 0:
        raise ValueError(
            "Q_rej requires alpha > 1/3 for the Marsaglia-Tsang map, got %r" % (alpha,))

    # Float conversion: np.zeros_like on an int array would truncate densities to 0,
    # and boolean masking would fail on plain lists/scalars.
    eps = np.asarray(eps, dtype=float)
    cutoff = np.sqrt(9. * shape_shift)
    valid = eps > -cutoff  # strict: h == 0 at eps == -cutoff, where log(h) = -inf
    eps_valid = eps[valid]

    h = fun_H(eps_valid, alpha)
    log_q = ((alpha - 1.) * np.log(h)
             - h
             + np.log(np.abs(derH_eps(eps_valid, alpha)))
             - sp.gammaln(alpha))

    Q = np.zeros(eps.shape, dtype=float)
    Q[valid] = np.exp(log_q)
    return Q

# Compare the noise densities Q(eps) of G-REP and RSVI against a standard
# normal, one panel per Gamma shape parameter alpha, and save the figure as PDF.
alphas = [1.0, 2.0, 10.0]
epsilon = np.linspace(-3., 3., 100)
yticks = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]

# plt.subplots replaces the old add_subplot('131') calls. Recent Matplotlib
# rejects string subplot specs, and the pre-loop call also created a stray axes.
fig, axes = plt.subplots(1, len(alphas), figsize=(16, 4), squeeze=False)
for ax, alpha in zip(axes[0], alphas):
    ax.plot(epsilon, np.exp(-0.5 * epsilon ** 2) / np.sqrt(2. * np.pi),
            'k--', lw=4, label='$\\mathcal{N}(0,1)$')
    ax.plot(epsilon, np.exp(logQ_grep(epsilon, alpha)),
            color=colors[0], lw=4, label='$\\mathrm{G-REP}$')
    # For alpha < 1 the RSVI curve is drawn at alpha + 1 (presumably the RSVI
    # shape-augmentation trick; Q_rej itself requires alpha > 1/3).
    alpha_rej = alpha + 1. if alpha < 1. else alpha
    ax.plot(epsilon, Q_rej(epsilon, alpha_rej),
            color=colors[1], lw=4, label='$\\mathrm{RSVI}$')
    ax.tick_params(axis='both', which='both', labelsize=16)
    ax.set_xlabel("$\\varepsilon$", fontsize=28)
    ax.set_title('$\\mathbf{\\alpha} = ' + str(alpha) + '$', fontsize=18)
    ax.set_yticks(yticks)
    ax.set_yticklabels([str(t) for t in yticks])

# A single shared legend is attached to the last panel and placed outside it,
# so it is passed to savefig as an extra artist to keep it inside the tight bbox.
lgd = ax.legend(fontsize=18, bbox_to_anchor=(0.1, 0.6, 1.5, 0.102))
filename = './compare_Qeps.pdf'
fig.savefig(filename, bbox_extra_artists=(lgd,), bbox_inches='tight', dpi=300)



In [ ]: