In [61]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os
Конечные формулы для апостериорных распределений $p(z_i|x_i, \theta)$ вычисляются по стандартной формуле Байеса, при условии равномерного априорного распределения на $z_i$: $$ p(z_i|x_i, \theta) = \frac{p(x_i|z_i, \theta)}{\sum_k p(x_i|z_i = k, \theta)} $$ А так как $$ p(x_i|z_i, \theta) = \prod_{i=1}^D (\theta_{k,i})^{x_{n,i}}(1-\theta_{k,i})^{(1-x_{n,i})} $$ То мы можем уже сейчас реализовать функцию $\texttt{posterior}$. Так как знаменатель получившейся дроби -- это именно функция правдаподобия, то объединим две функции в одну для предоствращения повторного запуска одного и того же кода
In [901]:
def posterior(x, clusters):
'''Returns z: (K x D) array of posterior probability of object 'i' to be in cluster 'k' '''
return posterior_and_loglike(x,clusters)[0]
def likelihood(x, clusters, ignore_neginf = False):
'''Returns ll -- numpy integer of loglikelihood.'''
if ignore_neginf:
#for test sample sometimes we have horrible prediction, ignote them when counting loglike
ll = posterior_and_loglike2(x,clusters)[1]
return sum(ll[np.where(ll > -1e200)]) - x.shape[0] * np.log( clusters.shape[0] )
else:
return sum(posterior_and_loglike(x,clusters)[1]) - x.shape[0] * np.log( clusters.shape[0] )
def posterior_and_loglike(x, clusters):
'''Returns z and ll. See functions posterior and likelihood'''
K, D, N = clusters.shape[0], clusters.shape[1], x.shape[0]
# clip clusters for test sample
clusters = np.clip(clusters, 0, 1)
# initialize z
z = np.zeros( (K, N) , dtype=np.float)
#fill it by row with log's
for k in xrange( K ):
z[k,:] = np.sum(np.where(x > 0, np.log(clusters[k]), np.log(1-clusters[k]) ), axis = 1 )
#substract the largest element for computational stability
z = np.nan_to_num(z)
largest = np.max( z, axis = 0 )
z -= largest
#denominator that is the sum of all elements (by row)
denom = np.log( np.sum( np.exp( z ), axis = 0 ) )
assert not np.any(np.isinf(denom))
assert not np.any(np.isinf(largest))
return np.exp( z - denom ), denom + largest
Далее реализовываем E-шаг алгоритма. Необходимо вычислить $$ \theta^{new} = \arg \max_{\theta^{old}} \mathbb{E}[\mathcal{L(\theta^{old})}] $$ Как выясняется, найти максимум можно аналитически. $$ \mathbb{E}[\mathcal{L(\theta)}] = \sum_{n=1}^N \sum_{k=1}^K \mathbb{E}[\mathcal{z_{n,k}}] \left( \sum_{i=1}^D x_{n,i}\ln \theta_{k,i} + (1-x_{n,i})\ln(1-\theta_{k,i}) \right) $$ $$ \frac{\partial}{\partial \theta} \mathbb{E}[\mathcal{L(\theta})]=\sum_{n=1}^N \mathbb{E}[\mathcal{z_{n,k}}] \frac{x_{n,i} - \theta_{k,i}}{\theta_{k,i}(1-\theta_{k,i})} \Rightarrow $$ $$ \theta_{k,i} = \frac{\sum_{n=1}^N x_{n,i}z_{n,k} }{\sum_{n=1}^N z_{n,k}} $$ С точки зрения кода это очень легко реализовать
In [348]:
def learn_clusters(x, z):
'''Returns clusters: K x D array, each row describes k-th row of Bernoulli distribution'''
clusters = np.dot(z, x) / np.sum(z, axis=1)[:,None]
return clusters
Реализуем также $\texttt{em_algorithm}$. Немного отойдём от спецификации, добавив параметр stop_iter_percent -- когда процентное изменение среднего правдоподобия будет на очередном шаге меньше этого параметра, то алгоритм выходит из цикла, чтобы избежать лишних ненужных подсчётов
In [520]:
def em_algorithm(x, K, maxiter, stop_iter_percent = None, visualize = True, verbose = True):
assert stop_iter_percent is not None and stop_iter_percent < 1 and stop_iter_percent > 0, "Are u serious?!"
N, D = x.shape
clusters = rand.uniform( size = (K, x.shape[1]) )
#initialize clusters and loglikes histories
clusters_hist, loglikes = clusters[None,:].copy(), np.array( [likelihood(x, clusters)] )
#for checking previous loglike values:
loglike_past = -np.inf
#do algorithm maxiter times
for numiter in xrange(maxiter):
z, loglike = posterior_and_loglike(x, clusters)
loglike = sum(loglike) - N * np.log( K )
loglikes = np.append(loglikes, loglike)
clusters = learn_clusters(x, z)
clusters_hist = np.vstack( (clusters_hist, clusters[None, :]) )
if verbose:
print "N_iter=%i, current average loglike=%.3f"%(numiter, loglike/N)
if visualize:
show_data(clusters, n = K, n_col = min(K,12), cmap = plt.cm.hot, interpolation = 'nearest')
if stop_iter_percent is not None and abs((loglike-loglike_past)/loglike_past) < stop_iter_percent:
if verbose:
print "N_iter=%i successful converge! Final average loglike=%.3f"%(numiter, loglike/N)
break
loglike_past = loglike
return loglikes, clusters_hist
Загрузим обучающие и тестовые выборки
In [ ]:
raw_train = np.loadtxt( open( 'mnist_train.csv', "rb" ), delimiter = ",", skiprows = 0 )
mnist_labels, mnist_data = raw_train[:,:1], raw_train[:,1:]
raw_test = np.loadtxt( open( 'mnist_test.csv', "rb" ), delimiter = ",", skiprows = 0 )
test_labels, test_data = raw_test[:,:1], raw_test[:,1:]
Также подготовим визуализацию картинок
In [122]:
def arrange_flex( images, n_row = 10, n_col = 10, N = 28, M = 28, fill_value = 0 ) :
## Function by Ivan Nazarov and Anna Potapenko
## Create the final grid of images row-by-row
im_grid = np.full( ( n_row * N, n_col * M ), fill_value, dtype = images.dtype )
for k in range( min( images.shape[ 0 ], n_col * n_row ) ) :
## Get the grid cell at which to place the image
i, j = ( k // n_col ) * N, ( k % n_col ) * M
## Just put the image in the cell
im_grid[ i:i+N, j:j+M ] = np.reshape( images[ k ], ( N, M, ) )
return im_grid
def show_data( data, n, n_col = 10, transpose = False, **kwargs ) :
## Function by Ivan Nazarov and Anna Potapenko
## Get the number of rows necessary to plot the needed number of images
n_row = ( n + n_col - 1 ) // n_col
## Transpose if necessary
if transpose :
n_col, n_row = n_row, n_col
## Set the dimensions of the figure
fig = plt.figure( figsize = ( n_col, n_row ) )
axis = fig.add_subplot( 111 )
axis.imshow( arrange_flex( data[:n], n_col = n_col, n_row = n_row ), **kwargs )
## Plot
plt.show( )
Берём 6 и 9
In [ ]:
data = mnist_data[ np.where( mnist_labels == 6 )[0] | np.where( mnist_labels == 9 )[0] ]
In [511]:
loglikes, clusters_hist = em_algorithm(data, 2, 30, stop_iter_percent=0.001)
In [512]:
plt.plot(loglikes);
Как видим, уже за 12 итераций алгоритм сходится к максимуму неполного правдоподобия. Последние графики шаблонов вполне напоминают "типовые" 6 и 9. Заметим, что округлая часть цифр более плотная, сравнивая с "хвостом". Это связано с тем, что хвост может лежать под разным углом, а расположение круглой части в среднем более сконцентрировано.
In [513]:
loglikes, clusters_hist = em_algorithm(mnist_data, 10, 50, stop_iter_percent=0.001)
In [514]:
loglikes, clusters_hist = em_algorithm(mnist_data, 15, 50, stop_iter_percent=0.001)
In [517]:
loglikes, clusters_hist = em_algorithm(mnist_data, 20, 50, stop_iter_percent=0.001)
Как видно, $k=10$ недостаточно для того, чтобы выделить шаблоны всех цифр, при этом при $k=15$ для каждой цифры есть шаблон. Однако заметим также, что уже при $k=15$ появляется некий странный шаблон (правый верхний), вероятно, аккумулирующий все необычные написания разных цифр. Также заметим, что цифры 5 и 6, 9 и 4, 3 и 9 не всегда легко отличить друг от друга.
Наконец, цифры 0, 6, 9, и 2 требуют нескольких шаблонов. Это связано с существованием большого числа вариантов их написания у разных людей, и алгоритм заводит разные шаблоны под разное написание цифр.
Попробуем теперь увеличивать число $K$ до очень больших
In [573]:
res_different_K = {}
for K in range(15,80,3):
loglikes, clusters_hist = em_algorithm(mnist_data, K, 50, stop_iter_percent=0.001, verbose = False, visualize=False)
res_different_K[K] = (loglikes, clusters_hist)
print "K=%i successful. Loglike = %.3f"%(K, loglikes[-1])
In [ ]:
for K in range(80,200,7):
loglikes, clusters_hist = em_algorithm(mnist_data, K, 50, stop_iter_percent=0.001, verbose = False, visualize=False)
res_different_K[K] = (loglikes, clusters_hist)
print "K=%i successful. Loglike = %.3f"%(K, loglikes[-1])
График для разных $K$
In [759]:
plt.plot(*zip(*(sorted(res_different_K_loglike.items(), key=itemgetter(0)))));
К сожалению, результат неутешительный. С ростом $k$ должен был бы ожидаться, с какого-то момента, подение правдоподобия. Однако это не проивзошло. Причина может крыться в размерах выборки (очень большая выборка позволяет иметь большое число степеней свободы), либо банальный overfitting, не наблюдаемый на loglikelihood.
Посмотрим корреляцию правдоподобия тестовой и обучающей выборок
In [913]:
loglikes = [ [ items[0][-1] / mnist_data.shape[0], likelihood(test_data, items[1][-1], ignore_neginf=True) / test_data.shape[0] ]
for K, items in res_different_K.items() ]
In [943]:
ax = plt.subplot()
ax.scatter(*zip(*loglikes))
ax.set_xlabel('Train sample avg loglike')
ax.set_ylabel('Test sample avg loglike');
Результат достаточно неожиданный: переобученность абсолютно не наблюдается, корреляция невероятно высока
Когда есть апостериорная вероятность объекта, то выбрать шаблон очень просто:
In [561]:
def classify(data, clusters):
z = posterior(data, clusters)
return np.argmax(z, axis=0)
Так как у меня не удалось достичь максимума правдоподобия на тестовой выборке, пришлось ограничиться неким разумным $k$. Я остановился на $k=75$, то есть в среднем одной цифре будет соответствовать $7.5$ шаблонов. Получившиеся шаблоны:
In [652]:
show_data( optim_clusters[-1], n = optim_k, n_col = 17, cmap = plt.cm.hot, interpolation = 'nearest' )
Разметим обучающую выборку
In [586]:
classified = classify(mnist_data,optim_clusters[-1])
И посмотрим на распределение числа наблюдений в кластере.
In [594]:
from collections import Counter
c = Counter(classified)
In [779]:
plt.hist(c.values());
print "Median number of items in one cluster: %i" %np.median(c.values())
Как видно, у нас почти 700 объектов в классе, однако есть малочисленные шаблоны (например, закорючка во втором снизу ряду). Присвоим каждому шаблону наиболее распространенную в нем цифру, при этом малочисленные шаблону (менее 100 наблюдений максимум) будем игнорировать.
In [647]:
labeled_clusters = np.empty(len(c), dtype = int)
In [649]:
for k in range(len(c)):
occurs = Counter(mnist_labels[classified == k] )
labeled_clusters[k] = occurs.most_common(1)[0][0] if occurs.most_common(1)[0][1] > 100 else -1
In [773]:
labeled_clusters
Out[773]:
Классифицируем тестовую выборку.
In [655]:
test_classified = classify(test_data,optim_clusters[-1])
In [678]:
print "Correctly specified: %i" % sum(labeled_clusters[test_classified] == test_labels)
print "Could not specify: %i" % sum(labeled_clusters[test_classified] == -1)
print "False specified: %i" % sum(labeled_clusters[test_classified] != test_labels)
print "Percentage of correct: %.f%%" % (100.*( sum(labeled_clusters[test_classified] == test_labels)) / len(test_classified) )
In [1024]:
#error counts matrix. Row -- true value, column -- predicted
error_counts = np.empty((10,10))
In [1040]:
for i,j in np.ndindex((10,10)):
error_counts[i,j] = sum( (test_labels == i) & (labeled_clusters[test_classified] == j) )
error_counts = error_counts
In [1042]:
np.fill_diagonal(error_counts, 0)
In [1046]:
fig, ax = plt.subplots()
heatmap = ax.pcolor(error_counts, cmap=plt.cm.Reds)
# put the major ticks at the middle of each cell, notice "reverse" use of dimension
ax.set_yticks(np.arange(error_counts.shape[0])+0.5, minor=False)
ax.set_ylabel('True')
ax.set_xticks(np.arange(error_counts.shape[1])+0.5, minor=False)
ax.set_xlabel('Predicted')
ax.set_xticklabels(range(10), minor=False)
ax.set_yticklabels(range(10), minor=False)
fig.colorbar(heatmap)
plt.show()
Как видно из графика, чаще всего мы ошибаемся на девятке, присваивая ей четверку, и наоборот. Также неочевидной проблемой оказалась пара 7-9, при этом мы чаще предсказываем девятку, а на самом деле это семерка. Это может говорить о переобученности девятки, мы слишком часто ей предстказываем.
Также интересно будет посмотреть на примеры цифр, которые мы верно классифицировали, а также на цифры, в которых мы допустили ошибку.
In [ ]:
false_classified_id = np.where(labeled_clusters[test_classified] != test_labels)[0]
correct_classified_id = np.where(labeled_clusters[test_classified] == test_labels)[0]
Правильно классифицированные:
In [709]:
show_data( test_data[ rand.choice( correct_classified_id, 102, replace = True ) ],
n = 102, n_col = 17, cmap = plt.cm.hot, interpolation = 'nearest' )
Неправильно классифицированные
In [715]:
show_data( test_data[ rand.choice( false_classified_id, 102, replace = True ) ],
n = 102, n_col = 17, cmap = plt.cm.hot, interpolation = 'nearest' )