In [1]:
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark Feedforward neural network example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

In [2]:
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Load training data
df = spark.read.format('com.databricks.spark.csv').\
                               options(header='true', \
                               inferschema='true').load("./data/WineData.csv",header=True);
df.show(5)


+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|chlorides|free sulfur dioxide|total sulfur dioxide|density|  pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
|          7.4|             0.7|        0.0|           1.9|    0.076|               11.0|                34.0| 0.9978|3.51|     0.56|    9.4|      5|
|          7.8|            0.88|        0.0|           2.6|    0.098|               25.0|                67.0| 0.9968| 3.2|     0.68|    9.8|      5|
|          7.8|            0.76|       0.04|           2.3|    0.092|               15.0|                54.0|  0.997|3.26|     0.65|    9.8|      5|
|         11.2|            0.28|       0.56|           1.9|    0.075|               17.0|                60.0|  0.998|3.16|     0.58|    9.8|      6|
|          7.4|             0.7|        0.0|           1.9|    0.076|               11.0|                34.0| 0.9978|3.51|     0.56|    9.4|      5|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
only showing top 5 rows


In [3]:
df.printSchema()


root
 |-- fixed acidity: double (nullable = true)
 |-- volatile acidity: double (nullable = true)
 |-- citric acid: double (nullable = true)
 |-- residual sugar: double (nullable = true)
 |-- chlorides: double (nullable = true)
 |-- free sulfur dioxide: double (nullable = true)
 |-- total sulfur dioxide: double (nullable = true)
 |-- density: double (nullable = true)
 |-- pH: double (nullable = true)
 |-- sulphates: double (nullable = true)
 |-- alcohol: double (nullable = true)
 |-- quality: integer (nullable = true)


In [4]:
# Convert to float format
def string_to_float(x):
    return float(x)

# 
def condition(r):
    if (0<= r <= 4):
        label = "low" 
    elif(4< r <= 6):
        label = "medium"
    else: 
        label = "high" 
    return label

In [5]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, DoubleType
string_to_float_udf = udf(string_to_float, DoubleType())
quality_udf = udf(lambda x: condition(x), StringType())

In [6]:
#df= df.withColumn("quality", string_to_float_udf("quality")).withColumn("Cquality", quality_udf("quality"))
df= df.withColumn("quality", quality_udf("quality"))

In [7]:
df.printSchema()


root
 |-- fixed acidity: double (nullable = true)
 |-- volatile acidity: double (nullable = true)
 |-- citric acid: double (nullable = true)
 |-- residual sugar: double (nullable = true)
 |-- chlorides: double (nullable = true)
 |-- free sulfur dioxide: double (nullable = true)
 |-- total sulfur dioxide: double (nullable = true)
 |-- density: double (nullable = true)
 |-- pH: double (nullable = true)
 |-- sulphates: double (nullable = true)
 |-- alcohol: double (nullable = true)
 |-- quality: string (nullable = true)


In [8]:
df.show()


+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|chlorides|free sulfur dioxide|total sulfur dioxide|density|  pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
|          7.4|             0.7|        0.0|           1.9|    0.076|               11.0|                34.0| 0.9978|3.51|     0.56|    9.4| medium|
|          7.8|            0.88|        0.0|           2.6|    0.098|               25.0|                67.0| 0.9968| 3.2|     0.68|    9.8| medium|
|          7.8|            0.76|       0.04|           2.3|    0.092|               15.0|                54.0|  0.997|3.26|     0.65|    9.8| medium|
|         11.2|            0.28|       0.56|           1.9|    0.075|               17.0|                60.0|  0.998|3.16|     0.58|    9.8| medium|
|          7.4|             0.7|        0.0|           1.9|    0.076|               11.0|                34.0| 0.9978|3.51|     0.56|    9.4| medium|
|          7.4|            0.66|        0.0|           1.8|    0.075|               13.0|                40.0| 0.9978|3.51|     0.56|    9.4| medium|
|          7.9|             0.6|       0.06|           1.6|    0.069|               15.0|                59.0| 0.9964| 3.3|     0.46|    9.4| medium|
|          7.3|            0.65|        0.0|           1.2|    0.065|               15.0|                21.0| 0.9946|3.39|     0.47|   10.0|   high|
|          7.8|            0.58|       0.02|           2.0|    0.073|                9.0|                18.0| 0.9968|3.36|     0.57|    9.5|   high|
|          7.5|             0.5|       0.36|           6.1|    0.071|               17.0|               102.0| 0.9978|3.35|      0.8|   10.5| medium|
|          6.7|            0.58|       0.08|           1.8|    0.097|               15.0|                65.0| 0.9959|3.28|     0.54|    9.2| medium|
|          7.5|             0.5|       0.36|           6.1|    0.071|               17.0|               102.0| 0.9978|3.35|      0.8|   10.5| medium|
|          5.6|           0.615|        0.0|           1.6|    0.089|               16.0|                59.0| 0.9943|3.58|     0.52|    9.9| medium|
|          7.8|            0.61|       0.29|           1.6|    0.114|                9.0|                29.0| 0.9974|3.26|     1.56|    9.1| medium|
|          8.9|            0.62|       0.18|           3.8|    0.176|               52.0|               145.0| 0.9986|3.16|     0.88|    9.2| medium|
|          8.9|            0.62|       0.19|           3.9|     0.17|               51.0|               148.0| 0.9986|3.17|     0.93|    9.2| medium|
|          8.5|            0.28|       0.56|           1.8|    0.092|               35.0|               103.0| 0.9969| 3.3|     0.75|   10.5|   high|
|          8.1|            0.56|       0.28|           1.7|    0.368|               16.0|                56.0| 0.9968|3.11|     1.28|    9.3| medium|
|          7.4|            0.59|       0.08|           4.4|    0.086|                6.0|                29.0| 0.9974|3.38|      0.5|    9.0|    low|
|          7.9|            0.32|       0.51|           1.8|    0.341|               17.0|                56.0| 0.9969|3.04|     1.08|    9.2| medium|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
only showing top 20 rows


In [9]:
# convert the data to dense vector
def transData(data):
    return data.rdd.map(lambda r: [r[-1], Vectors.dense(r[:-1])]).toDF(['label','features'])

In [10]:
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

data= transData(df)
data.show()


+------+--------------------+
| label|            features|
+------+--------------------+
|medium|[7.4,0.7,0.0,1.9,...|
|medium|[7.8,0.88,0.0,2.6...|
|medium|[7.8,0.76,0.04,2....|
|medium|[11.2,0.28,0.56,1...|
|medium|[7.4,0.7,0.0,1.9,...|
|medium|[7.4,0.66,0.0,1.8...|
|medium|[7.9,0.6,0.06,1.6...|
|  high|[7.3,0.65,0.0,1.2...|
|  high|[7.8,0.58,0.02,2....|
|medium|[7.5,0.5,0.36,6.1...|
|medium|[6.7,0.58,0.08,1....|
|medium|[7.5,0.5,0.36,6.1...|
|medium|[5.6,0.615,0.0,1....|
|medium|[7.8,0.61,0.29,1....|
|medium|[8.9,0.62,0.18,3....|
|medium|[8.9,0.62,0.19,3....|
|  high|[8.5,0.28,0.56,1....|
|medium|[8.1,0.56,0.28,1....|
|   low|[7.4,0.59,0.08,4....|
|medium|[7.9,0.32,0.51,1....|
+------+--------------------+
only showing top 20 rows


In [11]:
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
# Index labels, adding metadata to the label column.
# Fit on whole dataset to include all labels in index.
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
labelIndexer.transform(data).show(6)


+------+--------------------+------------+
| label|            features|indexedLabel|
+------+--------------------+------------+
|medium|[7.4,0.7,0.0,1.9,...|         0.0|
|medium|[7.8,0.88,0.0,2.6...|         0.0|
|medium|[7.8,0.76,0.04,2....|         0.0|
|medium|[11.2,0.28,0.56,1...|         0.0|
|medium|[7.4,0.7,0.0,1.9,...|         0.0|
|medium|[7.4,0.66,0.0,1.8...|         0.0|
+------+--------------------+------------+
only showing top 6 rows


In [12]:
# Automatically identify categorical features, and index them.
# Set maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer =VectorIndexer(inputCol="features", \
                              outputCol="indexedFeatures", \
                              maxCategories=4).fit(data)

featureIndexer.transform(data).show(6)


+------+--------------------+--------------------+
| label|            features|     indexedFeatures|
+------+--------------------+--------------------+
|medium|[7.4,0.7,0.0,1.9,...|[7.4,0.7,0.0,1.9,...|
|medium|[7.8,0.88,0.0,2.6...|[7.8,0.88,0.0,2.6...|
|medium|[7.8,0.76,0.04,2....|[7.8,0.76,0.04,2....|
|medium|[11.2,0.28,0.56,1...|[11.2,0.28,0.56,1...|
|medium|[7.4,0.7,0.0,1.9,...|[7.4,0.7,0.0,1.9,...|
|medium|[7.4,0.66,0.0,1.8...|[7.4,0.66,0.0,1.8...|
+------+--------------------+--------------------+
only showing top 6 rows


In [13]:
data.printSchema()


root
 |-- label: string (nullable = true)
 |-- features: vector (nullable = true)


In [14]:
# Split the data into train and test
(trainingData, testData) = data.randomSplit([0.6, 0.4])

In [15]:
data.show()


+------+--------------------+
| label|            features|
+------+--------------------+
|medium|[7.4,0.7,0.0,1.9,...|
|medium|[7.8,0.88,0.0,2.6...|
|medium|[7.8,0.76,0.04,2....|
|medium|[11.2,0.28,0.56,1...|
|medium|[7.4,0.7,0.0,1.9,...|
|medium|[7.4,0.66,0.0,1.8...|
|medium|[7.9,0.6,0.06,1.6...|
|  high|[7.3,0.65,0.0,1.2...|
|  high|[7.8,0.58,0.02,2....|
|medium|[7.5,0.5,0.36,6.1...|
|medium|[6.7,0.58,0.08,1....|
|medium|[7.5,0.5,0.36,6.1...|
|medium|[5.6,0.615,0.0,1....|
|medium|[7.8,0.61,0.29,1....|
|medium|[8.9,0.62,0.18,3....|
|medium|[8.9,0.62,0.19,3....|
|  high|[8.5,0.28,0.56,1....|
|medium|[8.1,0.56,0.28,1....|
|   low|[7.4,0.59,0.08,4....|
|medium|[7.9,0.32,0.51,1....|
+------+--------------------+
only showing top 20 rows


In [16]:
# specify layers for the neural network:
# input layer of size 11 (features), two intermediate of size 5 and 4
# and output of size 7 (classes)
layers = [11, 5, 4, 4, 3 , 7]

# create the trainer and set its parameters
FNN = MultilayerPerceptronClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures",\
                                         maxIter=100, layers=layers, blockSize=128, seed=1234)

In [17]:
# Convert indexed labels back to original labels.
labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",
                               labels=labelIndexer.labels)

In [18]:
# Chain indexers and forest in a Pipeline
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, FNN, labelConverter])

In [19]:
# train the model
# Train model.  This also runs the indexers.
model = pipeline.fit(trainingData)

In [20]:
# Make predictions.
predictions = model.transform(testData)

In [21]:
# Select example rows to display.
predictions.select("features","label","predictedLabel").show(5)


+--------------------+-----+--------------+
|            features|label|predictedLabel|
+--------------------+-----+--------------+
|[5.1,0.585,0.0,1....| high|        medium|
|[5.2,0.48,0.04,1....| high|        medium|
|[5.4,0.835,0.08,1...| high|        medium|
|[5.5,0.49,0.03,1....| high|        medium|
|[5.6,0.66,0.0,2.2...| high|        medium|
+--------------------+-----+--------------+
only showing top 5 rows


In [22]:
# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Predictions accuracy = %g, Test Error = %g" % (accuracy,(1.0 - accuracy)))


Predictions accuracy = 0.808642, Test Error = 0.191358