In [1]:
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import ds100
In [2]:
np.random.seed(13337)
c1 = np.random.randn(25, 2)
c2 = np.array([2, 8]) + np.random.randn(25, 2)
c3 = np.array([8, 4]) + np.random.randn(25, 2)
x1 = np.vstack((c1, c2, c3))
g1 = np.repeat([0, 1, 2], 25)
ds100.scatter2d_grouped(x1, g1)
Let's just run this algorithm as a black box for the moment to see how well it performs.
In [3]:
example1 = ds100.kmeans(x1, 3)
example1.run()
example1.plot()
Dang, looks pretty good up to label permutation. With such promising results, we should pry under the hood a little. We discover that K-means can be described as follows:
This seems fairly innocuous. There are a couple of ways for basic k-means to start off:

- Pick the initial centers at random. The `kmeans` object included in the `ds100` module does this by default.
- Specify the initial centers manually via the `centers` argument when instantiating the `kmeans` object.

What do you think? When would you pick one method over another?
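As a sketch of the default strategy (the `ds100.kmeans` internals aren't shown here, so `init_centers` is a hypothetical helper, not the actual implementation), random initialization might look like:

```python
import numpy as np

def init_centers(x, k, seed=None):
    """Pick k distinct data points at random to serve as initial centers."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(x), size=k, replace=False)
    return x[idx]

x = np.arange(20, dtype=float).reshape(10, 2)
centers = init_centers(x, 3, seed=0)
print(centers.shape)  # (3, 2)
```

Starting from actual data points guarantees no center begins stranded in an empty region of the feature space.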
Notation: For the remainder of the discussion, we'll refer to the K clusters as $C_1, C_2, ..., C_K$ and their specific coordinates as $c_1, c_2, ..., c_K$.
How do we measure "closeness"? With k-means, we use the squared Euclidean distance. Formally, for any two points $x$ and $c$, each a vector with $p$ coordinates (one per feature), we can write this "dissimilarity" as: $$d(x, c) = \lVert x-c \rVert_2^2 = \sum_{j=1}^p (x_j - c_j)^2$$
With this measure, assign each of the $n$ data points, $x_i$, $i \in \{1, 2, 3, ..., n\}$, to the cluster $C_k$ that is closest to it.
It turns out that this "rule" isn't well-defined. When is there ambiguity? How do you propose we fix this?
Now that the cluster assignments have changed, we need to find their centers. This is a straightforward calculation. For each cluster, we just take the average of all the points assigned to that cluster:
$$c_k = \frac{1}{|C_k|}\sum_{i \in C_k} x_i$$ where $|C_k|$ is the size of the cluster.
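The two update steps above can be sketched straight from the formulas (hypothetical helper names; `ds100.kmeans` hides its own versions behind `_update_clusters` and `_update_centers`, which we'll call shortly):

```python
import numpy as np

def assign_clusters(x, centers):
    """Assign each point to its nearest center.

    Ties go to the lowest-index center, which is how np.argmin
    resolves the ambiguity mentioned above.
    """
    # dists[i, k] = squared distance from point i to center k
    dists = np.sum((x[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2, axis=2)
    return np.argmin(dists, axis=1)

def update_centers(x, clusters, k):
    """Recompute each center as the mean of its assigned points."""
    return np.array([x[clusters == j].mean(axis=0) for j in range(k)])

x = np.array([[0., 0.], [0., 1.], [10., 10.], [10., 11.]])
centers = np.array([[0., 0.], [10., 10.]])
clusters = assign_clusters(x, centers)
print(clusters)                        # [0 0 1 1]
print(update_centers(x, clusters, 2))  # [[ 0.   0.5] [10.  10.5]]
```

Note this sketch would produce a `nan` center if a cluster ever ended up empty; a real implementation has to handle that case somehow.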
In [4]:
manual_centers = np.array([[0,1], [1,1], [2,2]])
example2 = ds100.kmeans(x1, k = 3, centers = manual_centers)
example2.plot(colored=False)
We'll run the algorithm one step at a time to see what happens. First, let's assign clusters.
In [5]:
example2._update_clusters()
example2.show_clusters()
example2.plot()
Now to update the cluster centers. It seems like at least one center is moving in a reasonable direction.
In [6]:
example2._update_centers()
example2.show_centers()
example2.plot()
Continuing with a second cluster assignment... Doesn't look like much has changed.
In [7]:
example2._update_clusters()
example2.show_clusters()
example2.plot()
Letting k-means run on its own reveals that we could have run this for one more step, but it still stops in a pretty bad place. So indeed, k-means can fail to find a global optimum if it is seeded with a bad start.
In [8]:
example2 = ds100.kmeans(x1, k = 3, centers = manual_centers)
example2.run()
example2.summary()
example2.plot()
Despite its shortcomings, we should talk about its advantages.
With that said, let's build some more intuition behind what the algorithm is doing. Remember from lecture that the k-means objective function can be written as:
$$\operatorname*{argmin}_{C_1,...,C_K} \sum_{k=1}^K\sum_{i \in C_k} d(x_i, c_k) = \operatorname*{argmin}_{C_1,...,C_K} \sum_{k=1}^K\sum_{i \in C_k} \lVert x_i - c_k \rVert_2^2$$ In words: find the cluster assignments such that the sum of squares within clusters is minimized. Imagine drawing squares at each data point where one vertex is on the data point and the opposite vertex is on its cluster center. Add up the areas of all those squares. That is what we're trying to minimize by shuffling the data around to different clusters.
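The objective is easy to compute for a given clustering (a hypothetical helper for illustration; `ds100` may report this differently in `summary()`):

```python
import numpy as np

def within_cluster_ss(x, clusters, centers):
    """Sum over clusters of squared distances from each point to its center."""
    return sum(np.sum((x[clusters == k] - c) ** 2)
               for k, c in enumerate(centers))

x = np.array([[0., 0.], [0., 2.], [10., 10.]])
clusters = np.array([0, 0, 1])
centers = np.array([[0., 1.], [10., 10.]])
print(within_cluster_ss(x, clusters, centers))  # 2.0
```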
Consider the following data. Do you see any "natural" groupings?
In [9]:
x3 = np.genfromtxt('example3.csv', delimiter=',')
x3, g3 = x3[:,:2], x3[:,2]
plt.scatter(x3[:,0], x3[:,1], color=plt.cm.Dark2(.5))
Most people would pick out the following pattern. Doesn't seem too unreasonable.
In [10]:
ds100.scatter2d_grouped(x3, g3)
It turns out that k-means will pick something completely different. That pokeball though...
In [11]:
example3 = ds100.kmeans(x3, 2)
example3.run()
example3.plot()
So what's happening here? Remember the k-means objective function: minimize the sum of squares within clusters. Placing both centers at the origin and assigning the "natural" clusters would produce one "tight" cluster with small squares, but this is heavily overshadowed by the large squares resulting from the data points on the outer ring. In other words, k-means prefers clusters that are "separate balls of points".
Aside: This particular situation can actually be salvaged with k-means if we want to recover the "natural" clusters by transforming the data to polar coordinates.
In [12]:
r = np.sqrt(x3[:,0]**2 + x3[:,1]**2)
theta = np.arctan2(x3[:,1], x3[:,0])  # arctan2 handles all four quadrants and avoids dividing by zero
x3_xformed = np.hstack((r[:, np.newaxis], theta[:, np.newaxis]))
example3xf = ds100.kmeans(x3_xformed, 2)
example3xf.run()
example3xf.plot()
Transforming back to cartesian coordinates:
In [13]:
ds100.scatter2d_grouped(x3, example3xf.clusters)
Consider the data below. There are two groups of different sizes in two different senses. The smaller group has both smaller variability and is less numerous. The larger of the two groups is more diffuse and populated. What do you think happens when we run k-means and why?
In [14]:
c1 = 0.5 * np.random.randn(25, 2)
c2 = np.array([10, 10]) + 3*np.random.randn(475, 2)
x4 = np.vstack((c1, c2))
g4 = np.repeat([0, 1], [25, 475])
ds100.scatter2d_grouped(x4, g4)
Oi, it looks like it split up the larger group. Again this is all due to the nature of the objective function. k-means, in its quest for tightness, will happily split big clouds to minimize the sum of squares.
In [15]:
example4 = ds100.kmeans(x4, 2)
example4.run()
example4.plot()
Even with the true centers of the data generating process chosen, we still observe that k-means really wants to leech points off the large cluster.
In [16]:
smart_centers = [[0, 0], [10, 10]]
example4 = ds100.kmeans(x4, 2, centers = smart_centers)
example4.run()
example4.plot()
It's worth noting that this is mitigated if the different clusters are of the same size. The inertial mass of the data keeps the cluster center from moving too far away. Notice the outlier point that does get swallowed up in the orbit of the bottom-left cloud though.
In [17]:
c1 = 0.5 * np.random.randn(250, 2)
c2 = np.array([10, 10]) + 3*np.random.randn(250, 2)
x5 = np.vstack((c1, c2))
g5 = np.repeat([0, 1], [250, 250])
ds100.scatter2d_grouped(x5, g5)
In [18]:
example5 = ds100.kmeans(x5, 2)
example5.run()
example5.plot()
Let's take a look at this data. Qualitatively, what are some properties of the groups?
In [19]:
c1 = np.random.multivariate_normal([-1.5,0], [[.5,0],[0,4]], 100)
c2 = np.random.multivariate_normal([1.5,0], [[.5,0],[0,4]], 100)
c3 = np.random.multivariate_normal([0, 6], [[4,0],[0,.5]], 100)
x6 = np.vstack((c1, c2, c3))
g6 = np.repeat([0, 1, 2], 100)
ds100.scatter2d_grouped(x6, g6)
There are two groups with more variability in the vertical direction than the horizontal and one group where the opposite is true. Is this an issue for k-means? If so, what do you think is the root cause?
In [20]:
example6 = ds100.kmeans(x6, 3)
example6.run()
example6.plot()
In [21]:
example6 = ds100.kmeans(x6, 3)
example6.run()
example6.plot()
In [22]:
example6 = ds100.kmeans(x6, 3)
example6.run()
example6.plot()
So indeed k-means might struggle here as well, stemming precisely from the difference in the direction of intra-group variability. Recall that we are working with square euclidean distances. How might that explain these failure modes?
So we've seen a few examples where k-means fails to recover the true clusters in a plot. Under the hood, there seems to be a preference for non-overlapping clusters (see Voronoi diagrams), similarly-sized groups, and equal variability in all directions ("spheres"). But perhaps we're being disingenuous here.
These ideas are encapsulated in what are called No Free Lunch theorems, which in a nutshell say that any optimization algorithm trying to solve a real question is powered by ~~hopes and dreams~~ assumptions about the real world. Treat these objects as black boxes at your own peril.