Sparsified K-means Heuristic Plots

This notebook generates a set of figures that illustrate, heuristically, how sparsified k-means works.


In [1]:
import matplotlib as mpl
mpl.use('Agg') 
import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import hadamard
from scipy.fftpack import dct
%matplotlib inline

In [2]:
# --- Problem sizes shared by every figure in this notebook ---
# NOTE(review): the bracket labels drawn by drawbrackets() read "$n$ data points"
# and "$p$ dimensions", i.e. the opposite naming of n/p used here -- confirm
# which convention is intended.
n = 10     # dimension of the data -> number of rows drawn in each matrix plot
p = 6      # number of observations -> number of columns drawn for X
K = 3      # number of centroids -> number of columns drawn for U
m = 4      # subsampling dimension: entries kept (set to 1) per column of X

DPI = 300  # resolution, in dots per inch, used when saving figures

# Seed the global NumPy RNG so the random shuffles / sign flips are repeatable.
np.random.seed(0)

Data Matrices

X is the data matrix, U is the centroid matrix. We define these, and then set up a few tricks to mask certain regions of the arrays for plotting selected columns. The idea is to generate two copies of X, one which is used to plot only the column in which we are interested, the other of which is used to plot the remainder of the data. We make two different plots so that we can set different alpha values (transparency) for the column in question vs. the rest of the data.


In [12]:
def this_is_dumb(x):
    """Return a randomly shuffled copy of ``x`` as a NumPy array.

    The input is never modified: it is copied into a fresh array first and
    the copy is shuffled in place (along its first axis) using the global
    ``np.random`` state, so results depend on the current seed.
    """
    shuffled = np.array(x)        # np.array copies its input by default
    np.random.shuffle(shuffled)   # in-place shuffle of the copy only
    return shuffled

## Preconditioning plot

# Unconditioned: each of the p columns (observations) is an independent random
# permutation of the integers -5..4. The list has 10 entries, so this section
# requires n == 10 (D below is n x n and multiplies X_unconditioned from the left).
vals_unconditioned = list(range(-5, 5))
X_unconditioned = np.array([this_is_dumb(vals_unconditioned) for i in range(p)]).T

# Conditioned: flip the sign of each coordinate at random (diagonal matrix D),
# then apply an orthonormal DCT.
D = np.diag(np.random.choice([-1,1],n))
# axis=0 makes the DCT act on each column (one observation), i.e. along the same
# axis D acts on. scipy's default (axis=-1) would instead transform each row,
# mixing values across different observations.
X_conditioned = dct(np.dot(D,X_unconditioned), norm = 'ortho', axis = 0)

## Subsampling plots

# Template column: m ones followed by n - m zeros.
vals = [1] * m + [0] * (n - m)

# X has one column per observation; every column is an independent shuffle of
# the template, so each column contains exactly m ones at random rows.
X = np.array([this_is_dumb(vals) for _ in range(p)]).T

# Means (centroid) matrix: n rows, K columns, all zeros.
U = np.zeros((n, K))

# All-zero array with X's shape and dtype; used to plot the full data in X
# (before subsampling).
Z = np.zeros_like(X)

def get_col_X(col):
    """Build two masked copies of the global X for highlighting one column.

    ``col`` is 1-based (column ``col - 1`` of X is the one selected).

    Returns ``[YO, YC]`` where
      * YO is X with the selected column overwritten by -1 (the "others"), and
      * YC is -1 everywhere except the selected column, which holds X's values.

    The two arrays are drawn as separate images so that they can be given
    different alpha (transparency) values.
    """
    idx = col - 1

    others = np.copy(X)
    others[:, idx] = -1

    chosen = np.full_like(X, -1)
    chosen[:, idx] = X[:, idx]

    return [others, chosen]

def get_rows_U(col):
    """Return a copy of the global U with the rows picked by a column of X set to 1.

    ``col`` is 1-based. Every row index at which column ``col - 1`` of X equals 1
    is filled with ones (across all columns) in the returned copy; U itself is
    not modified.
    """
    selected_rows = X[:, col - 1] == 1   # boolean mask over the rows
    highlighted = np.copy(U)
    highlighted[selected_rows, :] = 1
    return highlighted

Color Functions

We import the colors from two text files, 'CM.txt' (main colors) and 'CA.txt' (alternate colors). These files are generated at www.paletton.com by exporting a palette as text. The text parsing is hacky but works fine for now, and it makes it easy to try out different color schemes by exporting them directly from Paletton.

We use the colors to set up a colormap that we'll apply to the data matrices. We manually set the boundaries on the colormap to agree with how we defined the various matrices above. This way we can get different colored blocks, etc.


In [4]:
def read_colors(path_in):
    """Parse the color values out of a palette text file.

    Every line whose first non-blank characters are ``shade`` is split on
    ``=``; the second field, stripped of surrounding whitespace, is collected.
    All other lines are ignored.

    Returns the collected strings as a list, in file order.
    """
    with open(path_in) as handle:
        stripped_lines = [raw.lstrip() for raw in handle.readlines()]
    return [
        entry.split("=")[1].strip()
        for entry in stripped_lines
        if entry.startswith('shade')
    ]

# Palettes: main (CM) and alternate (CA) shades parsed from the exported text
# files, plus a fixed list of grays (CD) used for axes, grid lines and boxes.
CM = read_colors('CM.txt')
CA = read_colors('CA.txt')
CD = ['#404040', '#585858', '#989898']

# Set the axes colors
mpl.rc('axes', edgecolor=CD[0], linewidth=1.3)

# Three-bin colormaps for the masked 0/1 matrices. With these bounds the bins
# are [-1, 0), [0, 1) and [1, 2]; the first bin (the -1 "masked" entries) uses
# the color 'none'.
bounds = [-1, 0, 1, 2]

cmapM = mpl.colors.ListedColormap(['none', CM[1], CM[3]])
normM = mpl.colors.BoundaryNorm(bounds, cmapM.N)

cmapA = mpl.colors.ListedColormap(['none', CA[1], CA[4]])
normA = mpl.colors.BoundaryNorm(bounds, cmapA.N)

# Colormap for the preconditioning figure: unit-width bins from -5 to 5, colored
# with the alternate shades (reversed) followed by the main shades.
# NOTE(review): this presumably needs len(CA) + len(CM) == 10 to match the ten
# bins -- confirm against the palette files.
bounds_unconditioned = list(range(-5, 6))
cmap_unconditioned = mpl.colors.ListedColormap(CA[::-1] + CM)
norm_unconditioned = mpl.colors.BoundaryNorm(bounds_unconditioned, cmap_unconditioned.N)

Plotting Functions


In [5]:
def drawbrackets(ax):
    """Annotate ``ax`` with the two labelled brackets that frame the data matrix.

    One bracket sits above the axes and one to its left (rotated text); both
    are positioned in axes-fraction coordinates. The bracket widths
    (``widthB``) are hard-coded numbers.

    NOTE(review): the labels say "$n$ data points" / "$p$ dimensions", while
    the notebook's config comments describe n as the dimension and p as the
    number of observations -- confirm which naming is intended.
    """
    shared = dict(xycoords='axes fraction', ha='center')

    # Bracket above the axes.
    ax.annotate(
        r'$n$ data points',
        xy=(0.502, 1.03),
        xytext=(0.502, 1.08),
        fontsize=14,
        va='bottom',
        arrowprops=dict(arrowstyle='-[, widthB=4.6, lengthB=0.35', lw=1.2),
        **shared)

    # Bracket to the left of the axes, with the label rotated by 90 degrees.
    ax.annotate(
        r'$p$ dimensions',
        xy=(-.060, 0.495),
        xytext=(-.22, 0.495),
        fontsize=16,
        va='center',
        rotation=90,
        arrowprops=dict(arrowstyle='-[, widthB=6.7, lengthB=0.36', lw=1.2, color='k'),
        **shared)
    
def drawbracketsU(ax):
    """Annotate ``ax`` with a labelled bracket ("$K$ centroids") above the axes.

    The bracket is positioned in axes-fraction coordinates and its width
    (``widthB``) is a hard-coded number.
    """
    bracket = dict(arrowstyle='-[, widthB=2.25, lengthB=0.35', lw=1.2)
    ax.annotate(
        r'$K$ centroids',
        xy=(0.505, 1.03),
        xytext=(0.505, 1.08),
        xycoords='axes fraction',
        fontsize=14,
        ha='center',
        va='bottom',
        arrowprops=bracket)
    
def formatax(ax):
    """Apply the shared matrix-plot formatting to ``ax``.

    * Hides every tick mark and the bottom/left tick labels.
    * Places major ticks at 0.5, 1.5, ... using the global sizes ``p`` (x) and
      ``n`` (y), so that grid lines can be drawn at those positions.
    * Draws solid vertical grid lines and thin dashed horizontal grid lines in
      the color ``CD[0]``.

    NOTE(review): the x ticks are always derived from ``p``, even when the axes
    is later used for the K-column centroid matrix -- confirm this is intended.
    """
    # Booleans are required here: the string 'off' is no longer interpreted by
    # matplotlib and, being a non-empty string, is truthy -- which would leave
    # the ticks and labels switched ON.
    ax.tick_params(
        axis='both',        # changes apply to both axes
        which='both',       # both major and minor ticks are affected
        bottom=False,       # no tick marks along the bottom edge
        top=False,          # no tick marks along the top edge
        left=False,         # no tick marks along the left edge
        right=False,        # no tick marks along the right edge
        labelbottom=False,  # no tick labels along the bottom edge
        labelleft=False)    # no tick labels along the left edge
    ax.set_xticks(np.arange(0.5, p-.5, 1))
    ax.set_yticks(np.arange(0.5, n-.5, 1))
    ax.grid(which='major', color = CD[0], axis = 'x', linestyle='-', linewidth=1.3)
    ax.grid(which='major', color = CD[0], axis = 'y', linestyle='--', linewidth=.5)
    
def drawbox(ax,col):
    """Draw a closed rectangular outline around column ``col`` (1-based) of ``ax``.

    The rectangle is derived from the axes' current tick positions: its
    horizontal edges are the first two x ticks shifted by ``col - 2``, and it
    spans from one unit below the first y tick to one unit above the last.
    The outline uses color ``CD[0]`` and is not clipped to the axes.
    """
    shift = col - 2
    xticks = ax.get_xticks()
    yticks = ax.get_yticks()

    x_lo, x_hi = xticks[0] + shift, xticks[1] + shift
    y_lo, y_hi = yticks[0] - 1, yticks[-1] + 1

    # Visit the four corners and return to the start to close the outline.
    outline_x = [x_lo, x_hi, x_hi, x_lo, x_lo]
    outline_y = [y_lo, y_lo, y_hi, y_hi, y_lo]
    ax.plot(outline_x, outline_y, color=CD[0], linewidth=3, clip_on=False)

    
def plot_column_X(ax,col):
    """Draw the data matrix on ``ax`` with column ``col`` (1-based) highlighted.

    Applies the shared formatting, the brackets and the box around the column,
    then draws two images from get_col_X: the non-selected columns with
    alpha 0.8, and the selected column on top with no alpha set.
    """
    formatax(ax)
    drawbrackets(ax)
    drawbox(ax, col)

    others, selected = get_col_X(col)
    image_opts = dict(interpolation='none', cmap=cmapM, norm=normM)
    ax.imshow(others, alpha=0.8, **image_opts)
    ax.imshow(selected, **image_opts)

def plot_column_U(ax,col):
    """Draw the means matrix on ``ax`` with the rows selected by column ``col`` highlighted.

    Applies the shared formatting and the "K centroids" bracket, then shows the
    array returned by get_rows_U(col) using the alternate colormap.
    """
    formatax(ax)
    drawbracketsU(ax)
    highlighted = get_rows_U(col)
    ax.imshow(highlighted, interpolation='none', cmap=cmapA, norm=normA)
    
def plot_column_selection(col,fn,save=False):
    """Create the two-panel figure for column ``col`` and either save or show it.

    The left panel is drawn by plot_column_X and the right panel by
    plot_column_U. If ``save`` compares equal to True the figure is written to
    ``fn`` at the global ``DPI``; otherwise ``plt.show()`` is called and
    ``fn`` is unused.
    """
    fig = plt.figure()
    layout = mpl.gridspec.GridSpec(1, 2, height_ratios=[1])
    data_ax = plt.subplot(layout[0])
    means_ax = plt.subplot(layout[1])

    plot_column_X(data_ax, col)
    plot_column_U(means_ax, col)

    if not (save == True):
        plt.show()
        return
    fig.savefig(fn, dpi=DPI)

Generate the Plots


In [13]:
# Preconditioning figure: X_unconditioned on the left, X_conditioned on the
# right, both drawn with the same colormap and norm. Shown only, not saved.
fig = plt.figure()
gs = mpl.gridspec.GridSpec(1, 2, height_ratios=[1])
ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1])

# Left panel (with brackets).
formatax(ax0)
drawbrackets(ax0)
ax0.imshow(X_unconditioned, interpolation='none',
           cmap=cmap_unconditioned, norm=norm_unconditioned)

# Right panel (no brackets).
formatax(ax1)
ax1.imshow(X_conditioned, interpolation='none',
           cmap=cmap_unconditioned, norm=norm_unconditioned)

plt.show()



In [7]:
# Figure 'mat0.png': the system before subsampling -- the all-zero array Z
# (left) next to the means matrix U (right).

fig = plt.figure()
gs = mpl.gridspec.GridSpec(1, 2, height_ratios=[1])
ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1])

# Left panel: data matrix placeholder.
formatax(ax0)
drawbrackets(ax0)
ax0.imshow(Z, interpolation='none', cmap=cmapM, norm=normM)

# Right panel: means matrix.
formatax(ax1)
drawbracketsU(ax1)
ax1.imshow(U, interpolation='none', cmap=cmapA, norm=normA)

plt.show()
fig.savefig('mat0.png', dpi=DPI)



In [8]:
# Figure 'mat1.png': the subsampled system -- X (left) next to the means
# matrix U (right).

fig = plt.figure()
gs = mpl.gridspec.GridSpec(1, 2, height_ratios=[1])
ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1])

# Left panel: subsampled data matrix.
formatax(ax0)
drawbrackets(ax0)
ax0.imshow(X, interpolation='none', cmap=cmapM, norm=normM)

# Right panel: means matrix.
formatax(ax1)
drawbracketsU(ax1)
ax1.imshow(U, interpolation='none', cmap=cmapA, norm=normA)

plt.show()
fig.savefig('mat1.png', dpi=DPI)



In [9]:
# Pick out the first column: X with column 1 highlighted (left) next to the
# unmodified means matrix U (right). Saved as 'mat2.png'.

fig = plt.figure()
gs = mpl.gridspec.GridSpec(1,2, height_ratios=[1])
ax0 = plt.subplot(gs[0])

# plot_column_X already applies formatax, drawbrackets and drawbox to the axes,
# so they are not called separately here; doing so would draw the brackets and
# the box twice on top of each other.
plot_column_X(ax0,1)


ax1 = plt.subplot(gs[1])
formatax(ax1)
drawbracketsU(ax1)

ax1.imshow(U,
          interpolation = 'none',
          cmap=cmapA,
          norm=normA)
plt.show()

fig.savefig('mat2.png',dpi=DPI)



In [10]:
# Make the "final" plots: one figure per column of X (p in total), each saved
# as col<i>.png with i running from 1 to p.
for i in range(1, p + 1):
    fn = 'col%s.png' % str(i)
    plot_column_selection(i, fn, save=True)