In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [17]:
def plot_clusters(all_samples, centroids, n_samples_per_cluster):
    #Plot out the different clusters
     #Choose a different colour for each cluster
    colour = plt.cm.rainbow(np.linspace(0,1,len(centroids)))
    for i, centroid in enumerate(centroids):
        #Grab just the samples fpr the given cluster and plot them out with a new colour
        samples = all_samples[i*n_samples_per_cluster:(i+1)*n_samples_per_cluster]
        plt.scatter(samples[:,0], samples[:,1], c=colour[i])
        #Also plot centroid
        plt.plot(centroid[0], centroid[1], markersize=15, marker="x", color='k', mew=5)
        plt.plot(centroid[0], centroid[1], markersize=15, marker="x", color='r', mew=2)
    plt.show()

def create_samples(n_clusters, n_samples_per_cluster, n_features, embiggen_factor, seed):
    np.random.seed(seed)
    slices = []
    centroids = []
    # Create samples for each cluster
    for i in range(n_clusters):
        samples = tf.random_normal((n_samples_per_cluster, n_features),
                                   mean=0.0, stddev=5.0, dtype=tf.float32, seed=seed, name="cluster_{}".format(i))
        current_centroid = (np.random.random((1, n_features)) * embiggen_factor) - (embiggen_factor/2)
        centroids.append(current_centroid)
        samples += current_centroid
        slices.append(samples)
    # Create a big "samples" dataset
    samples = tf.concat(slices, 0, name='samples')
    centroids = tf.concat(centroids, 0, name='centroids')
    return centroids, samples

def choose_random_centroids(samples, n_clusters):
    # Step 0: Initialisation: Select `n_clusters` number of random points
    n_samples = tf.shape(samples)[0]
    random_indices = tf.random_shuffle(tf.range(0, n_samples))
    begin = [0,]
    size = [n_clusters,]
    size[0] = n_clusters
    centroid_indices = tf.slice(random_indices, begin, size)
    initial_centroids = tf.gather(samples, centroid_indices)
    return initial_centroids

def assign_to_nearest(samples, centroids):
    # Finds the nearest centroid for each sample

    # START from http://esciencegroup.com/2016/01/05/an-encounter-with-googles-tensorflow/
    expanded_vectors = tf.expand_dims(samples, 0)
    expanded_centroids = tf.expand_dims(centroids, 1)
    distances = tf.reduce_sum( tf.square(
               tf.subtract(expanded_vectors, expanded_centroids)), 2)
    mins = tf.argmin(distances, 0)
    # END from http://esciencegroup.com/2016/01/05/an-encounter-with-googles-tensorflow/
    nearest_indices = mins
    return nearest_indices

def update_centroids(samples, nearest_indices, n_clusters):
    # Updates the centroid to be the mean of all samples associated with it.
    nearest_indices = tf.to_int32(nearest_indices)
    partitions = tf.dynamic_partition(samples, nearest_indices, n_clusters)
    new_centroids = tf.concat([tf.expand_dims(tf.reduce_mean(partition, 0), 0) for partition in partitions], 0)
    return new_centroids

In [29]:
data_centroids.shape, samples.shape, nearest_indices.shape


Out[29]:
(TensorShape([Dimension(3), Dimension(2)]),
 TensorShape([Dimension(1500), Dimension(2)]),
 TensorShape([Dimension(1500)]))

In [35]:
n_features = 2
n_clusters = 3
n_samples_per_cluster = 500
seed = 700
embiggen_factor = 70


data_centroids, samples = create_samples(n_clusters, n_samples_per_cluster, n_features, embiggen_factor, seed)
initial_centroids = choose_random_centroids(samples, n_clusters)
nearest_indices = assign_to_nearest(samples, initial_centroids)
updated_centroids = update_centroids(samples, nearest_indices, n_clusters)

model = tf.global_variables_initializer()
with tf.Session() as session:
    samp_expand = tf.expand_dims(samples,0)
    cent_expand = tf.expand_dims(initial_centroids,1)
    diff = tf.square(tf.subtract(samp_expand, cent_expand))
    dists = tf.reduce_sum(diff,2)
    mins = tf.argmin(dists, 0)
    
    pyvars = session.run([samp_expand, cent_expand,diff,dists,mins])
    for p in pyvars: print p.shape, p[0],"\n"

    sample_values = session.run(samples)
    updated_centroid_value = session.run(updated_centroids)
    print(updated_centroid_value)

# plot_clusters(sample_values, updated_centroid_value, n_samples_per_cluster)


(1, 1500, 2) [[-24.234732    -5.6361933 ]
 [-21.307875     0.05049753]
 [-29.070787   -15.994404  ]
 ...
 [-19.84385      8.688988  ]
 [-12.377249    19.237701  ]
 [-21.833637    21.711456  ]] 

(3, 1, 2) [[-32.86691    -2.6321235]] 

(3, 1500, 2) [[ 74.51449     9.024435 ]
 [133.61128     7.1964555]
 [ 14.41054   178.55052  ]
 ...
 [169.60008   128.16757  ]
 [419.82617   478.28928  ]
 [121.733086  592.6099   ]] 

(3, 1500) [ 83.538925 140.80774  192.96106  ... 297.76764  898.1155   714.343   ] 

(1500,) 2 

[[-28.299948   13.762105 ]
 [-15.540976    4.9136477]
 [-20.921425   27.978687 ]]