In [ ]:
%matplotlib inline
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
In [2]:
from bernoullimix import BernoulliMixture, random_mixture_generator
In [3]:
# Ground-truth mixture: K components (rows of TRUE_P) over D binary features (columns).
K, D = 4, 5
TRUE_PI = np.array([0.5, 0.25, 0.15, 0.1])
TRUE_P = np.array([[0.9, 0.2, 0.3, 0.1, 0.83],
                   [0.8, 0.8, 0.2, 0.9, 0.4],
                   [0.1, 0.1, 0.9, 0.8, 0.2],
                   [0.1, 0.4, 0.5, 0.05, 0.04]])
RANDOM_STATE = 1207

# Sanity-check the ground truth before anything is sampled from it.
assert TRUE_PI.shape == (K,), "TRUE_PI needs exactly one weight per component"
assert np.isclose(TRUE_PI.sum(), 1.0), "mixing coefficients must sum to 1"
# The comparison plots align fitted components to the truth by sorting fitted
# weights in descending order, so the true weights must be listed that way too.
assert np.all(np.diff(TRUE_PI) <= 0), "TRUE_PI must be sorted in descending order"
assert TRUE_P.shape == (K, D), "TRUE_P must be a K x D matrix"
assert np.all((TRUE_P >= 0) & (TRUE_P <= 1)), "Bernoulli parameters must lie in [0, 1]"
In [4]:
# Ground-truth model used to generate the synthetic data. Argument order is
# presumably (n_components, n_features, mixing coefficients, emission
# probabilities) — confirm against the bernoullimix.BernoulliMixture signature.
true_mixture = BernoulliMixture(K, D, TRUE_PI, TRUE_P)
In [5]:
N_SAMPLES = 3000
# Rows [0, N_SPLIT_HALF) later get one of two complementary column-missingness
# patterns; rows [N_SPLIT_HALF, N_SAMPLES) all share FULL_HALF_MISSING_COLUMNS.
# Any odd leftover row lands in the second ("full") half.
N_SPLIT_HALF = N_SAMPLES // 2
N_FULL_HALF = N_SAMPLES - N_SPLIT_HALF
# Columns 1 and 3 are hidden for every row of the second half.
FULL_HALF_MISSING_COLUMNS = np.array([False, True, False, True, False])
In [6]:
# Draw N_SAMPLES observations from the ground-truth model (seeded for
# reproducibility). The second return value is presumably the generating
# component of each row; it is not used by the later cells shown here.
full_observations, true_components = true_mixture.sample(N_SAMPLES, random_state=RANDOM_STATE)
In [7]:
# TODO(bernoullimix): `sample` should already return a DataFrame; wrap the
# result so later cells can use pandas indexing (.iloc, .columns, .isnull()).
full_observations = pd.DataFrame(data=full_observations)
In [8]:
# Blank out blocks of a copy so full_observations keeps every original value.
missing_observations = full_observations.copy()
# (row selector, column selector) pairs, applied in order:
#   first quarter  -> leading D//2 columns missing
#   second quarter -> trailing columns missing (complement of the above)
#   second half    -> columns flagged in FULL_HALF_MISSING_COLUMNS missing
for rows, cols in (
    (slice(None, N_SPLIT_HALF // 2), slice(None, D // 2)),
    (slice(N_SPLIT_HALF // 2, N_SPLIT_HALF), slice(D // 2, None)),
    (slice(N_SPLIT_HALF, None), FULL_HALF_MISSING_COLUMNS),
):
    missing_observations.iloc[rows, cols] = None
In [9]:
# Peek at the last five rows (second half: only the FULL_HALF mask applies).
missing_observations.iloc[-5:]
Out[9]:
In [10]:
# One tuple of per-column null flags per row; counting them shows how many rows
# follow each missingness pattern.
missingness_patterns = missing_observations.isnull().apply(tuple, axis=1)
missingness_patterns.value_counts()
Out[10]:
In [11]:
N_MIXTURES_TO_FIT: int = 500  # number of random restarts; the highest-`ll` fit is kept
In [12]:
# Source of candidate K-component mixtures used as random restarts (each one is
# fitted below). It also receives the data; presumably this is used to shape or
# initialise the candidates — confirm in bernoullimix.
generator = random_mixture_generator(K, missing_observations, random_state=RANDOM_STATE)
In [13]:
import itertools

# Take the first N_MIXTURES_TO_FIT candidates from the generator.
mixtures = [candidate for candidate in itertools.islice(generator, N_MIXTURES_TO_FIT)]
In [14]:
%%time
convergences = []
lls = []
for mixture in mixtures:
ll, convergence = mixture.fit(missing_observations, iteration_limit=None, convergence_threshold=1e-3)
convergences.append(convergence)
lls.append(ll)
In [15]:
# Default RangeIndex: each label equals the position of the restart in `mixtures`.
lls = pd.Series(data=lls)
In [16]:
# Keep the restart with the largest `ll` (NaNs skipped). `lls` carries a default
# RangeIndex, so its labels are positions in `mixtures`. idxmax is used instead of
# Series.argmax, whose meaning changed across pandas versions (label before 1.0,
# position from 1.0 on).
best_mixture = mixtures[lls.idxmax()]
In [ ]:
In [17]:
# Raw fitted emission probabilities of the best restart, in the model's own
# component order (reordering by weight happens in a later cell).
best_mixture.emission_probabilities
Out[17]:
In [18]:
# Label the fitted parameters: one weight per component, and a component x feature
# table whose rows share the weights' index and whose columns match the data.
answer_pis = pd.Series(data=best_mixture.mixing_coefficients)
answer_ems = pd.DataFrame(
    data=best_mixture.emission_probabilities,
    index=answer_pis.index,
    columns=missing_observations.columns,
)
In [19]:
# Fitted components come out in arbitrary order. Sort by weight (largest first)
# to line them up with TRUE_PI, which is listed in descending order, and move the
# emission rows to the same order.
answer_pis = answer_pis.sort_values(ascending=False)
answer_ems = answer_ems.reindex(answer_pis.index)
In [22]:
# True vs fitted mixing coefficients, side by side on a shared [0, 1] colour scale.
# The fitted weights keep their component labels as column ticks, so they can be
# matched against the row labels of the emission-probability heatmap.
fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(15, 5))
sns.heatmap(TRUE_PI[np.newaxis, :], annot=True, ax=ax_true, vmin=0, vmax=1,
            yticklabels=False)
ax_true.set(title='True', xlabel='component')
sns.heatmap(answer_pis.to_frame().T, annot=True, ax=ax_pred, vmin=0, vmax=1,
            yticklabels=False)
ax_pred.set(title='Predicted', xlabel='fitted component');
Out[22]:
In [20]:
# True vs fitted per-component emission probabilities (rows = components,
# columns = features) on a shared [0, 1] colour scale. Fitted rows are sorted by
# weight and labelled with the model's original component ids.
fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(15, 5))
sns.heatmap(TRUE_P, annot=True, ax=ax_true, vmin=0, vmax=1)
ax_true.set(title='True', xlabel='feature', ylabel='component')
sns.heatmap(answer_ems, annot=True, ax=ax_pred, vmin=0, vmax=1)
ax_pred.set(title='Predicted', xlabel='feature', ylabel='fitted component');
Out[20]:
In [ ]: