In [17]:
import numpy as np
from pyspark.mllib.clustering import KMeans
import matplotlib.pyplot as plt
from pyspark.sql import SQLContext, Row
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%pylab inline
In [19]:
# Human-readable names of the four iris measurements.
feature_names = [
    'sepal length (cm)',
    'sepal width (cm)',
    'petal length (cm)',
    'petal width (cm)',
]

# Species names; a name's position in this list becomes its numeric index.
target_names = ['setosa', 'versicolor', 'virginica']

# RDD of (name, index) pairs built from the list above.
target_indices = sc.parallelize(target_names).zipWithIndex()
# Single-argument call form prints the same text under Python 2 and is also
# valid Python 3 (the bare `print x` statement is Python 2 only).
print(target_indices.take(3))

# DataFrame with one row per species: `type` holds the name, `index` the id.
names_df = target_indices.map(
    lambda pair: Row(type=pair[0], index=pair[1])).toDF()
In [20]:
# Read the iris CSV as an RDD of raw text lines and mark it for caching,
# since later cells derive more than one dataset from it.
random_points = sc.textFile("../data/iris.csv").cache()

# Quick look at five lines, sampled with replacement.
random_points.takeSample(withReplacement=True, num=5)
Out[20]:
In [21]:
# Parse every CSV line into a typed Row: four float measurements followed by
# the species name. Blank lines (e.g. a trailing newline at end of file) are
# skipped because float('') would raise and abort the whole job.
data_df = (random_points
           .filter(lambda line: line.strip())
           .map(lambda line: line.split(","))
           .map(lambda fields: Row(
               sepal_length=float(fields[0]),
               sepal_width=float(fields[1]),
               petal_length=float(fields[2]),
               petal_width=float(fields[3]),
               type=fields[4]))
           .toDF())

# Feature vectors for clustering. Fields are read by name rather than by
# position: the positional order of a keyword-built Row differs between Spark
# versions, so row[0]..row[3] is not guaranteed to follow the order written
# above. Going through `.rdd` keeps this working on Spark versions where
# DataFrame itself has no `map` method.
data = data_df.rdd.map(lambda row: np.array([
    row.sepal_length,
    row.sepal_width,
    row.petal_length,
    row.petal_width,
]))

data_df.take(5)
Out[21]:
In [22]:
# Number of clusters to fit; equals the number of species in target_names.
centers = 3

# K-means starts from randomly chosen centroids, so without a fixed seed the
# cluster ids and fitted centres can change from one run to the next.
model = KMeans.train(data, centers, seed=42)
In [23]:
# Print every fitted cluster centre next to its cluster id.
# The explicit format call yields the same "id centre" text as the Python 2
# statement `print i, center` while also being valid Python 3.
for i, center in enumerate(model.clusterCenters):
    print("{0} {1}".format(i, center))
In [24]:
# Broadcast the fitted model so executors read it through a shared broadcast
# variable, then assign every feature vector to its nearest cluster.
m2 = sc.broadcast(model)
labels = data.map(lambda point: m2.value.predict(point)).collect()

# Fetch the three plotted measurements in one job and unpack them into plain
# lists of floats (a per-column collect() would give lists of 1-field Rows and
# run three separate jobs).
plot_rows = data_df.select('petal_width', 'sepal_length', 'petal_length').collect()
X = [row[0] for row in plot_rows]  # petal width
Y = [row[1] for row in plot_rows]  # sepal length
Z = [row[2] for row in plot_rows]  # petal length
In [25]:
# 3D scatter of the samples, coloured by the cluster id k-means assigned.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X, Y, Z, c=labels)

# Hide the tick labels on all three axes. The plain xaxis/yaxis/zaxis
# attributes are used because the w_xaxis-style aliases are deprecated and
# absent from recent Matplotlib releases.
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_ticklabels([])

ax.set_xlabel('Petal width')
ax.set_ylabel('Sepal length')
ax.set_zlabel('Petal length')
Out[25]:
In [54]:
# Numeric class index of every sample, in the same row order as data_df.
# The lookup is done locally with a small dict instead of a DataFrame join:
# a join gives no guarantee about output row order, so its result could be
# misaligned with coordinates collected separately from data_df.
type_to_index = target_indices.collectAsMap()
ind = [type_to_index[row[0]] for row in data_df.select('type').collect()]
ind[:5]
Out[54]:
In [60]:
# 3D scatter of the samples coloured by their true species, with each species
# name drawn near the mean position of its samples.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Axis order used everywhere below: x, y, z.
plot_cols = ('petal_width', 'sepal_length', 'petal_length')

# One collect for all three coordinates, unpacked into plain float lists.
points = data_df.select(*plot_cols).collect()
ax.scatter([p[0] for p in points],
           [p[1] for p in points],
           [p[2] for p in points],
           c=np.ravel(ind))  # flatten so 1-field Rows and plain ints both work

for name, _ in target_indices.collect():
    # avg(*cols) returns the averages in the order the columns are listed;
    # agg() with a dict does not fix the column order, so positional access
    # into its result could pick the wrong axis.
    X_mean = (data_df.filter(data_df.type == name)
              .groupBy()
              .avg(*plot_cols)
              .collect()[0])
    print("{0} {1}".format(X_mean, name))
    # The label is shifted by 1.5 along the sepal-length axis.
    ax.text3D(X_mean[0], X_mean[1] + 1.5, X_mean[2], name,
              horizontalalignment='center',
              bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))

# Hide the tick labels; xaxis/yaxis/zaxis replace the deprecated w_* aliases.
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_ticklabels([])

ax.set_xlabel('Petal width')
ax.set_ylabel('Sepal length')
ax.set_zlabel('Petal length')
Out[60]:
In [ ]: