In [ ]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
In [ ]:
from numpy.random import rand, randn
In [ ]:
# Problem size for the toy clustering example.
n = 100  # number of sample points (rows of X)
d = 2    # dimensionality of each point (columns of X)
k = 2    # number of clusters
In [ ]:
# Fix the global NumPy RNG state so the synthetic data are reproducible.
np.random.seed(20)

# n sample points; every coordinate is drawn uniformly from [0, 1).
X = rand(n, d)

# Cluster centres.  For an arbitrary k one could use
#     means = [rand(d) for _ in range(k)]
# but with k = 2 the plot reads better when the two centres are pushed
# into opposite halves of the unit square.
means = [
    0.5 + 0.5 * rand(d),  # every coordinate in [0.5, 1)
    0.5 - 0.5 * rand(d),  # every coordinate in (0, 0.5]
]

# Diagonal covariance with random entries from [0, 1).  The very same
# matrix object is reused for all k clusters for better visual results.
S = np.diag(rand(d))
sigmas = [S for _ in range(k)]

print(means, sigmas, sep="\n")
In [ ]:
def compute_log_p(X, mean, sigma):
    """Return the Gaussian log-density of every row of X.

    Evaluates ``log N(x | mean, sigma)`` for a multivariate normal
    distribution:

        -0.5 * (d * log(2*pi) + log|sigma| + (x - mean)^T sigma^-1 (x - mean))

    Args:
        X: array-like of shape (n, d), one point per row.  A single point
            of shape (d,) is treated as one row.
        mean: array-like of shape (d,), the centre of the Gaussian.
        sigma: array-like of shape (d, d), the covariance matrix.

    Returns:
        numpy.ndarray of shape (n,) holding the log-density of each row.

    Raises:
        ValueError: if the determinant of ``sigma`` is not strictly
            positive, in which case the log-density is undefined.
    """
    X = np.atleast_2d(np.asarray(X, dtype=float))
    mean = np.asarray(mean, dtype=float)
    sigma = np.asarray(sigma, dtype=float)

    # slogdet avoids the overflow/underflow that log(det(sigma)) can hit.
    sign, log_det = np.linalg.slogdet(sigma)
    if sign <= 0:
        raise ValueError("sigma must be a positive-definite covariance matrix")

    diff = X - mean
    # Solve sigma @ Z = diff.T instead of forming the explicit inverse;
    # the row-wise dot product then yields the squared Mahalanobis distance.
    solved = np.linalg.solve(sigma, diff.T).T
    mahalanobis_sq = np.sum(diff * solved, axis=1)

    dim = X.shape[1]
    return -0.5 * (dim * np.log(2.0 * np.pi) + log_det + mahalanobis_sq)
In [ ]:
# Log-density of every point under each cluster's Gaussian: one entry per
# (mean, sigma) pair.
# Exercise: try to do this without looping.
log_ps = [compute_log_p(X, *params) for params in zip(means, sigmas)]
In [ ]:
# Hard assignment: for each point, pick the index of the cluster whose
# log-density is largest (axis 0 runs over clusters).
assignments = np.asarray(log_ps).argmax(axis=0)
print(assignments)
In [ ]:
# Map cluster index -> colour name (index 0 is red, index 1 is green).
palette = np.array(['red', 'green'])
colors = palette[assignments]

# Data points, coloured by the cluster they were assigned to.
plt.scatter(X[:, 0], X[:, 1], c=colors, s=100)

# Cluster centres drawn on top as larger star markers.
centers = np.array(means)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', s=200)