In [ ]:
import sys

# Log the interpreter version so the results can be tied to a specific Python build.
print(sys.version, file=sys.stdout)

In [ ]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import time

import pandas as pd
import seaborn as sns

In [ ]:
import sys

# Make the project's helper modules in ../code/ importable.
# The membership check keeps sys.path from gaining a duplicate entry
# every time this cell is re-executed.
CODE_DIR = '../code/'
if CODE_DIR not in sys.path:
    sys.path.append(CODE_DIR)

from mnist_helpers import mnist_training, mnist_testing
from k_means import KMeans
from pca import Pca, make_image

In [ ]:
# Sanity check: confirm the pickled PCA model file exists, and show its size and timestamp, before the next cell loads it.
! ls -l ./data/PCA_training_data.pickle

In [ ]:
import pickle

# Load the pickled PCA object. It is used below via transform_number_up
# and is passed to KMeans as pca_obj.
# The `with` block closes the file handle; the previous bare open() leaked it.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('./data/PCA_training_data.pickle', "rb") as pca_file:
    pca_training = pickle.load(pca_file)

In [ ]:
# Load raw MNIST unshuffled (presumably so row i matches row i of the cached
# PCA-transformed arrays, whose file names indicate 50 components).
X_train_untransformed, y_train = mnist_training(shuffled=False)
X_train = np.load('../notebooks/data/X_transformed_by_50_components.npy')
print("X_train shape: {}.  y_train shape: {}".format(X_train.shape, y_train.shape))

X_test_untransformed, y_test = mnist_testing(shuffled=False)
X_test = np.load('../notebooks/data/X_test_transformed_by_50_components.npy')
print("X_test shape: {}.  y_test shape: {}".format(X_test.shape, y_test.shape))

# The transformed features come from .npy files loaded separately from the labels,
# so check that the row counts agree before anything pairs them up.
assert X_train.shape[0] == y_train.shape[0] == X_train_untransformed.shape[0], \
    "cached X_train rows do not match the MNIST training set"
assert X_test.shape[0] == y_test.shape[0] == X_test_untransformed.shape[0], \
    "cached X_test rows do not match the MNIST test set"

In [ ]:
# Debug mode trains on a small prefix of the data so the notebook runs quickly.
# Set to False for the full 60,000-point training run.
debug = True
if debug:
    N_points = 500
    # Slice features and labels together so they stay row-aligned.
    X_train = X_train[0:N_points,]
    y_train = y_train[0:N_points,]
    print("Update: X_train shape: {}.  y_train shape: {}".format(
            X_train.shape, y_train.shape))
else:
    N_points = 60000
    assert X_train_untransformed.shape[0] == N_points
    # Also check the arrays that are actually handed to KMeans. If this cell
    # previously ran with debug=True, X_train/y_train are still truncated in
    # the kernel, and the data-loading cell must be re-run first.
    assert X_train.shape[0] == N_points, \
        "X_train is truncated; re-run the data-loading cell"
    assert y_train.shape[0] == N_points, \
        "y_train is truncated; re-run the data-loading cell"

In [ ]:
# A bare `X_train.shape` mid-cell is never displayed, so print it explicitly.
print("X_train shape: {}".format(X_train.shape))
# Raw training image #100, followed by images built via transform_number_up
# (presumably PCA reconstructions back to pixel space) of training rows 1 and 100
# and test row 100. Each call is on its own line; two calls on one line was a SyntaxError.
make_image(X_train_untransformed[100])
make_image(pca_training.transform_number_up(X_train[1]))
make_image(pca_training.transform_number_up(X_train[100]))
make_image(pca_training.transform_number_up(X_test[100]))

In [ ]:
# Fit k-means with 250 centers on the PCA-transformed training set.
# The test split and the PCA object are passed in too (presumably for
# evaluation and for mapping centers back to images -- confirm in k_means.KMeans).
km = KMeans(
    k=250,
    train_X=X_train,
    train_y=y_train,
    pca_obj=pca_training,
    max_iter=1000,
    test_X=X_test,
    test_y=y_test,
    verbose=False,
)
km.run()

In [ ]:
# Convergence flag of the k-means run (presumably False if max_iter was reached -- confirm in k_means.KMeans).
km.converged

In [ ]:
# Iteration count recorded by the k-means run (presumably bounded by max_iter).
km.num_iter

In [ ]:
# Only a cell's last expression is displayed, so print the 0-1 loss explicitly.
# Otherwise it is computed and silently discarded.
print("Normalized 0-1 loss: {}".format(km.loss_01_normalized()))
km.num_assignments_per_cluster()

In [ ]:
# First rows of the run's results table (presumably one row per iteration -- confirm in k_means.KMeans).
km.results_df.head()

In [ ]:
# List what is currently in the ../figures directory.
! ls ../figures

In [ ]:
# Plot the squared reconstruction error. Assigning the result keeps its repr out of the cell's Out[] display.
p1 = km.plot_squared_reconstruction_error()

In [ ]:
# Plot the 0-1 loss. Note: this reassigns p1, overwriting the previous cell's plot object.
p1 = km.plot_0_1_loss()

In [ ]:
# Plot how many points are assigned to each cluster center. This reassigns p1 again.
p1 = km.plot_num_assignments_for_each_center()

In [ ]:


In [ ]:
import os

# Visualize 8 cluster centers with top=False. The output directory is
# ~/Downloads: expanduser resolves to the same path the author hardcoded
# (/Users/janet/Downloads) but works for any user.
km.visualize_n_centers(8, top=False, dir=os.path.expanduser('~/Downloads'))

In [ ]:
import os

# Visualize 8 cluster centers with top=True. The output directory is
# ~/Downloads: expanduser resolves to the same path the author hardcoded
# (/Users/janet/Downloads) but works for any user.
km.visualize_n_centers(8, top=True, dir=os.path.expanduser('~/Downloads'))

In [ ]: