In [1]:
import numpy as np
import os
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import scipy
import time

In [2]:
# Data generation
np.random.seed(1234)
mean1 = [1, 6]
cov1 = [[6, 0], [2, 8]]
x1, y1 = np.random.multivariate_normal(mean1,cov1,300).T

np.random.seed(1234)
mean2 = [-1, -6]
cov2 = [[5, 2], [1.3, 4]]
x2, y2 = np.random.multivariate_normal(mean2,cov2,300).T

np.random.seed(1234)
mean3 = [6, -2]
cov3 = [[3, 0], [0, 7]]
x3, y3 = np.random.multivariate_normal(mean3,cov3,300).T

x = np.hstack((x1,x2,x3))
y = np.hstack((y1,y2,y3))
data = np.vstack((x,y))

In [9]:
# Plot : labeled data points
plt.ion()
plt.figure()
plt.plot(x1,y1,'bo')
plt.plot(x2,y2,'ro')
plt.plot(x3,y3,'go')
plt.title('Ground truth')
plt.ylim(-15,15)
plt.xlim(-8,12)
plt.pause(0.5)
plt.show()
#input('Press <Enter> to continue')


C:\Program Files\Anaconda3\lib\site-packages\matplotlib\backend_bases.py:2437: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented
  warnings.warn(str, mplDeprecation)

In [10]:
# Plot : unlabeled data points
plt.figure()
plt.plot(x,y,'bo')
plt.title('Data to be clustered')
plt.ylim(-15,15)
plt.xlim(-8,12)
plt.pause(0.5)
plt.show()
#input('Press <Enter> to continue')


C:\Program Files\Anaconda3\lib\site-packages\matplotlib\backend_bases.py:2437: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented
  warnings.warn(str, mplDeprecation)

In [13]:
# K means clustering algorithm
nCluster = 3
np.random.seed(8888)
centerIdx = np.random.permutation(len(x))[0:nCluster]
centerX = x[centerIdx[0:nCluster]]
centerY = y[centerIdx[0:nCluster]]
dist = np.zeros((x.shape[0],nCluster))
plt.show()

In [17]:
for i in range(0,10):
    # Calculate distance from the centers to the data points
    for j in range(0,nCluster):
        dist[:,j] = np.sqrt(np.square(x-centerX[j])+np.square(y-centerY[j]))

    # A point belongs to the cluster of which center is the closest from here
    clusterIdx = np.argmin(dist,axis=1)

    # Update center point
    for j in range(0,nCluster):
        centerX[j] = np.mean(x[np.where(clusterIdx==j)])
        centerY[j] = np.mean(y[np.where(clusterIdx==j)])

    # Plot for real-time visualization
    plt.clf()
    plt.plot(x[np.where(clusterIdx==0)],y[np.where(clusterIdx==0)],'bo')
    plt.plot(x[np.where(clusterIdx==1)],y[np.where(clusterIdx==1)],'ro')
    plt.plot(x[np.where(clusterIdx==2)],y[np.where(clusterIdx==2)],'go')
    plt.plot(centerX,centerY,'kX', markerSize=20)
    plt.title('K means clustering (iteration='+str(i+1)+')')
    plt.ylim(-15,15)
    plt.xlim(-8,12)
    plt.pause(0.5)
    plt.show()
    plt.pause(0.5)


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-17-b380f0ac0318> in <module>()
     17     plt.plot(x[np.where(clusterIdx==1)],y[np.where(clusterIdx==1)],'ro')
     18     plt.plot(x[np.where(clusterIdx==2)],y[np.where(clusterIdx==2)],'go')
---> 19     plt.plot(centerX,centerY,'kX', markerSize=20)
     20     plt.title('K means clustering (iteration='+str(i+1)+')')
     21     plt.ylim(-15,15)

C:\Program Files\Anaconda3\lib\site-packages\matplotlib\pyplot.py in plot(*args, **kwargs)
   3159         ax.hold(hold)
   3160     try:
-> 3161         ret = ax.plot(*args, **kwargs)
   3162     finally:
   3163         ax.hold(washold)

C:\Program Files\Anaconda3\lib\site-packages\matplotlib\__init__.py in inner(ax, *args, **kwargs)
   1816                     warnings.warn(msg % (label_namer, func.__name__),
   1817                                   RuntimeWarning, stacklevel=2)
-> 1818             return func(ax, *args, **kwargs)
   1819         pre_doc = inner.__doc__
   1820         if pre_doc is None:

C:\Program Files\Anaconda3\lib\site-packages\matplotlib\axes\_axes.py in plot(self, *args, **kwargs)
   1380         kwargs = cbook.normalize_kwargs(kwargs, _alias_map)
   1381 
-> 1382         for line in self._get_lines(*args, **kwargs):
   1383             self.add_line(line)
   1384             lines.append(line)

C:\Program Files\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in _grab_next_args(self, *args, **kwargs)
    379                 return
    380             if len(remaining) <= 3:
--> 381                 for seg in self._plot_args(remaining, kwargs):
    382                     yield seg
    383                 return

C:\Program Files\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in _plot_args(self, tup, kwargs)
    329         ret = []
    330         if len(tup) > 1 and is_string_like(tup[-1]):
--> 331             linestyle, marker, color = _process_plot_format(tup[-1])
    332             tup = tup[:-1]
    333         elif len(tup) == 3:

C:\Program Files\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in _process_plot_format(fmt)
    114         else:
    115             raise ValueError(
--> 116                 'Unrecognized character %c in format string' % c)
    117 
    118     if linestyle is None and marker is None:

ValueError: Unrecognized character X in format string

In [ ]: