In [53]:
%matplotlib inline
import numpy as np
import numpy.random as rng
import pylab as pl

In [54]:
# Number of categories (bins); sets the length of the Dirichlet hyperparameter vectors built below.
N_categories = 20

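Draw `N_SRC` of the `N_counts` samples from the source categorical distribution and the remaining `N_counts - N_SRC` from the background distribution.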


In [55]:
def make_sample(Cat_SRC, Cat_BGD, N_SRC, N_counts):
    """
    Draw a two-component sample of categorical counts.

    Parameters
    ----------
    Cat_SRC : array_like of float
        Category probabilities of the source distribution.
    Cat_BGD : array_like of float
        Category probabilities of the background distribution.
    N_SRC : int
        Number of samples drawn from the source distribution.
    N_counts : int
        Total number of samples; the remaining ``N_counts - N_SRC`` are drawn
        from the background distribution.

    Returns
    -------
    counts_SRC, counts_BGD : ndarray of int
        Per-category counts contributed by the source and by the background.

    Raises
    ------
    ValueError
        If ``N_SRC`` is negative or larger than ``N_counts``.
    """
    if not 0 <= N_SRC <= N_counts:
        raise ValueError("need 0 <= N_SRC <= N_counts, got N_SRC=%r, N_counts=%r" % (N_SRC, N_counts))
    N_BGD = N_counts - N_SRC
    # One multinomial draw with n trials has the same distribution as summing
    # n one-hot (n=1) draws, without materialising an (n, K) array.
    counts_SRC = rng.multinomial(N_SRC, Cat_SRC)
    counts_BGD = rng.multinomial(N_BGD, Cat_BGD)
    return counts_SRC, counts_BGD

In [56]:
def _gammaln_table(size):
    """Return an array ``t`` of length ``size`` (>= 2) with ``t[k] = log Gamma(k)`` for integer k."""
    # log Gamma(0) = inf, log Gamma(1) = 0, and for k >= 2
    # log Gamma(k) = log((k-1)!) = sum_{j=1}^{k-1} log j.
    return np.append(np.array([np.inf, 0.]), np.cumsum(np.log(np.arange(1, size - 1))))


def get_logL_patch(counts, alphas):
    """
    Log likelihood of a vector of counts under a Dirichlet-multinomial distribution.

    Parameters
    ----------
    counts : array_like of non-negative int
        Number of observations in each bin (category).
    alphas : array_like of positive int
        Dirichlet hyperparameter for each bin, same length as ``counts``. Must be
        integers because log-Gamma values are read from an integer-indexed table.

    Returns
    -------
    float
        log P(counts | alphas) =
            log Gamma(A) - log Gamma(N + A)
            + sum_k [log Gamma(n_k + a_k) - log Gamma(a_k)]
            + log Gamma(N + 1) - sum_k log Gamma(n_k + 1),
        where N = sum(counts) and A = sum(alphas).

    Raises
    ------
    ValueError
        If any count is negative or any alpha is not positive.

    Notes
    -----
    A table of log Gamma(k) for integer k is cached on the function object
    (``get_logL_patch.gammaln``, valid for indices < ``get_logL_patch.Cmax``)
    and grown on demand, so repeated calls are cheap.
    """
    # asarray so that Python lists add elementwise instead of concatenating.
    counts = np.asarray(counts)
    alphas = np.asarray(alphas)
    # Negative values would silently wrap around when used as table indices.
    if np.any(counts < 0):
        raise ValueError("counts must be non-negative")
    if np.any(alphas <= 0):
        raise ValueError("alphas must be positive")
    N, A = np.sum(counts), np.sum(alphas)

    # Largest table index read below: N + A (normaliser) or N + 1 (multinomial coefficient).
    max_index = max(N + A, N + 1)
    # The cached table has Cmax entries (valid indices 0 .. Cmax - 1), so it must be
    # rebuilt whenever max_index >= Cmax. Cmax is absent before the first call.
    if max_index >= getattr(get_logL_patch, "Cmax", 0):
        print("get_logL_patch(): Calculating expensive lookup shit")
        get_logL_patch.Cmax = int(8 * max_index)  # over-allocate to amortise rebuilds
        get_logL_patch.gammaln = _gammaln_table(get_logL_patch.Cmax)
    lg = get_logL_patch.gammaln

    # Dirichlet part: Gamma(A) / Gamma(N + A) * prod_k Gamma(n_k + a_k) / Gamma(a_k).
    logL = lg[A] - lg[N + A] + np.sum(lg[counts + alphas]) - np.sum(lg[alphas])
    # Multinomial coefficient N! / prod_k n_k! = Gamma(N + 1) / prod_k Gamma(n_k + 1);
    # the +1 shifts keep zero counts finite (log Gamma(0) is inf).
    logL = logL + lg[N + 1] - np.sum(lg[counts + 1])
    return logL

In [80]:
# Simulate one patch: draw a categorical distribution from a symmetric Dirichlet whose
# hyperparameter is C in every bin, then draw N_draws samples from that distribution.
SEED = 42  # change to see a different random realisation
rng.seed(SEED)  # reseed here so re-running this cell reproduces the same sample
C = 100
N_draws = 20000
true_alphas = C * np.ones(N_categories, dtype=int)
Categorical = rng.dirichlet(true_alphas)
# One multinomial draw with N_draws trials is distributed like the sum of N_draws one-hot draws.
counts = rng.multinomial(N_draws, Categorical)
print(counts)

# Scan the symmetric hyperparameter of the model. The comprehension variable is C_test,
# so C still holds the true value after this cell runs.
Cs = np.arange(1, 400)
scores = [get_logL_patch(counts, C_test * np.ones(N_categories, dtype=int)) for C_test in Cs]

pl.plot(Cs, scores)
pl.axvline(C, color="k", linestyle="--", label="true C = %d" % C)
pl.xlabel("C (alpha in every bin)")
pl.ylabel("Dirichlet-multinomial log likelihood")
pl.legend()
pl.draw()


[ 969 1163 1157 1013 1073 1079  967 1031 1013  932  696 1093 1060  883  999
  878  882 1093  963 1056]

In [70]:


In [ ]: