In [73]:
# Display the SparkContext. `sc` is not created in this notebook; presumably it
# is injected by the pyspark kernel/launcher (TODO confirm for fresh kernels).
sc


Out[73]:
<pyspark.context.SparkContext at 0x1113eb358>

In [74]:
from pyspark.sql import SQLContext

# SQL/DataFrame entry point built on the existing SparkContext.
# NOTE(review): on Spark 2.0+ SparkSession supersedes SQLContext.
sqlContext = SQLContext(sparkContext=sc)

In [75]:
# Consolidated tweet dump (relative to the notebook's working directory).
TWEETS_PARQUET = "tweets.consolidated.parquet"
df = sqlContext.read.parquet(TWEETS_PARQUET)

In [76]:
# Peek at the first 20 tweets (long values are truncated in the display).
df.show(n=20)


+----------+------------------+--------------------+--------------------+--------+
|      user|                id|                text|            location|hasMedia|
+----------+------------------+--------------------+--------------------+--------+
| 429803867|668129332066459648|e0b40f2381c430f6d...|[27.166142,73.852...|   false|
|2575662781|668129436932415488|:) https://t.co/r...|[19.5371016,-96.9...|    true|
|2558754024|668128681945092096|برد 😊 (@ miral -...|[29.10425394,48.1...|   false|
| 175196235|668128627406610432|christmas market:...|[43.6506691,-79.3...|   false|
| 737480838|668128627394019328|يا عزيزي يالمدريد...|[26.21390031,50.4...|   false|
|  22921151|668129030068166657|#noelgeek #ghostb...|[45.50757496,-73....|   false|
|  93448793|668129332041265152|Soooooo these #ne...|[38.72750195,-90....|   false|
| 959736212|668128937801682945|Green Turtle in W...|[39.5640488,-76.9...|   false|
|  59972446|668129025890455552|#Retail #Job in #...|[41.4517093,-82.0...|   false|
|3234610719|668129269160222720|#StaracArabia
الن...|[30.0960606,31.33...|   false|
|2329037172|668128677763379201|Açlık oyunları al...|[38.33868221,27.1...|   false|
|  86583009|668128749032902656|#beaurivagegolf #...|[34.11432705,-77....|   false|
| 569410380|668129231386333185|@bm0406 @ionacrv ...|[19.09500403,72.8...|   false|
|2781520319|668129369819357184|#Bilinmezlik @ İz...|[38.29360749,27.1...|   false|
| 110213197|668128673560592384|Razón tenía aquel...|[9.91077394,-84.0...|   false|
|1179981192|668129055238062081|349.336 personas ...|   [40.4203,-3.7058]|   false|
| 156122032|668128820365512704|🎉🎉🎉 @ Quilmes,...| [-34.7203,-58.2694]|   false|
|  27737029|668129436898885633|Risottinho de moq...|[-22.97027778,-43...|   false|
| 245111438|668128317044826113|fish bowl fridays...|[42.09727335,-75....|   false|
| 543604821|668128518354698242|Viendo el partido...|[19.39476248,-99....|   false|
+----------+------------------+--------------------+--------------------+--------+
only showing top 20 rows


In [77]:
# Inspect the parquet schema: user/id/text, a nested location struct, and the
# boolean hasMedia flag that later becomes the classification label.
df.printSchema()


root
 |-- user: long (nullable = true)
 |-- id: long (nullable = true)
 |-- text: string (nullable = true)
 |-- location: struct (nullable = true)
 |    |-- latitude: double (nullable = true)
 |    |-- longitude: double (nullable = true)
 |-- hasMedia: boolean (nullable = true)


In [78]:
# Mark df for caching: several later cells (the hasMedia counts and the
# ml_df construction) run actions over it.
df = df.cache()

In [80]:
# Class balance of the future label: note how few tweets have media.
media_counts = df.groupBy(df.hasMedia).count()
media_counts.show()


+--------+-----+
|hasMedia|count|
+--------+-----+
|    true|  118|
|   false| 1967|
+--------+-----+


In [81]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.util import MLUtils
from pyspark.sql import Row
from pyspark.ml.feature import Tokenizer

In [82]:
def boolToInt(val):
    """Map a truthy value to 1.0 and anything falsy (including None) to 0.0.

    Useful for turning a boolean flag such as ``hasMedia`` into a numeric
    (double) label.
    """
    return 1.0 if val else 0.0

# Build the (id, label, text) frame used for ML with column expressions.
# Why: the work stays distributed and lazily derived from the cached df, with
# no collect() of every row into the driver. It also avoids DataFrame.map,
# which PySpark 2+ removed.
# when(hasMedia, 1.0).otherwise(0.0) gives true -> 1.0 and false/null -> 0.0.
from pyspark.sql import functions as F

ml_df = df.select(
    "id",
    F.when(F.col("hasMedia"), 1.0).otherwise(0.0).alias("label"),
    "text",
)

In [83]:
# Sanity-check the derived label against the first rows of text.
ml_df.show(n=20)


+------------------+-----+--------------------+
|                id|label|                text|
+------------------+-----+--------------------+
|668129332066459648|  0.0|e0b40f2381c430f6d...|
|668129436932415488|  1.0|:) https://t.co/r...|
|668128681945092096|  0.0|برد 😊 (@ miral -...|
|668128627406610432|  0.0|christmas market:...|
|668128627394019328|  0.0|يا عزيزي يالمدريد...|
|668129030068166657|  0.0|#noelgeek #ghostb...|
|668129332041265152|  0.0|Soooooo these #ne...|
|668128937801682945|  0.0|Green Turtle in W...|
|668129025890455552|  0.0|#Retail #Job in #...|
|668129269160222720|  0.0|#StaracArabia
الن...|
|668128677763379201|  0.0|Açlık oyunları al...|
|668128749032902656|  0.0|#beaurivagegolf #...|
|668129231386333185|  0.0|@bm0406 @ionacrv ...|
|668129369819357184|  0.0|#Bilinmezlik @ İz...|
|668128673560592384|  0.0|Razón tenía aquel...|
|668129055238062081|  0.0|349.336 personas ...|
|668128820365512704|  0.0|🎉🎉🎉 @ Quilmes,...|
|668129436898885633|  0.0|Risottinho de moq...|
|668128317044826113|  0.0|fish bowl fridays...|
|668128518354698242|  0.0|Viendo el partido...|
+------------------+-----+--------------------+
only showing top 20 rows


In [84]:
# Confirm the label column is a double, as the classifier stages below use it.
ml_df.printSchema()


root
 |-- id: long (nullable = true)
 |-- label: double (nullable = true)
 |-- text: string (nullable = true)


In [85]:
from pyspark.ml.feature import Word2Vec

# Split each tweet's text into tokens in a "words" column. Tokenize once and
# reuse the result for both fitting and transforming.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
tokenized_df = tokenizer.transform(ml_df)

# Learn a mapping from words to vectors and embed each tweet as "wordvecs".
# The fitted model gets its own name: `model` is reused later for the fitted
# pipeline, and sharing it made re-running this cell break the tree inspection.
# NOTE(review): vectorSize=3 is very small and minCount=0 keeps every one-off
# token (URLs, handles); presumably chosen for speed. Revisit for real use.
word2Vec = Word2Vec(vectorSize=3, minCount=0, inputCol="words", outputCol="wordvecs")
word2VecModel = word2Vec.fit(tokenized_df)
ml_vec_df = word2VecModel.transform(tokenized_df)

for feature in ml_vec_df.select("wordvecs").take(3):
    print(feature)


Row(wordvecs=DenseVector([-0.1051, -0.0159, 0.0675]))
Row(wordvecs=DenseVector([-0.219, 0.1291, 0.0865]))
Row(wordvecs=DenseVector([-0.1474, 0.0314, 0.0006]))

In [86]:
# Show text, tokens, and the resulting embedding side by side.
ml_vec_df.show(n=20)


+------------------+-----+--------------------+--------------------+--------------------+
|                id|label|                text|               words|            wordvecs|
+------------------+-----+--------------------+--------------------+--------------------+
|668129332066459648|  0.0|e0b40f2381c430f6d...|[e0b40f2381c430f6...|[-0.1050815135240...|
|668129436932415488|  1.0|:) https://t.co/r...|[:), https://t.co...|[-0.2190021015703...|
|668128681945092096|  0.0|برد 😊 (@ miral -...|[برد, 😊, (@, mir...|[-0.1473689468370...|
|668128627406610432|  0.0|christmas market:...|[christmas, marke...|[-0.0281017047153...|
|668128627394019328|  0.0|يا عزيزي يالمدريد...|[يا, عزيزي, يالمد...|[-0.0273308289237...|
|668129030068166657|  0.0|#noelgeek #ghostb...|[#noelgeek, #ghos...|[-0.1210067877545...|
|668129332041265152|  0.0|Soooooo these #ne...|[soooooo, these, ...|[-0.0882662865691...|
|668128937801682945|  0.0|Green Turtle in W...|[green, turtle, i...|[-0.0825941441580...|
|668129025890455552|  0.0|#Retail #Job in #...|[#retail, #job, i...|[0.13797220580662...|
|668129269160222720|  0.0|#StaracArabia
الن...|[#staracarabia, ا...|[-0.1060454837512...|
|668128677763379201|  0.0|Açlık oyunları al...|[açlık, oyunları,...|[-0.0877051008865...|
|668128749032902656|  0.0|#beaurivagegolf #...|[#beaurivagegolf,...|[-0.1311062552373...|
|668129231386333185|  0.0|@bm0406 @ionacrv ...|[@bm0406, @ionacr...|[4.58738093988762...|
|668129369819357184|  0.0|#Bilinmezlik @ İz...|[#bilinmezlik, @,...|[-0.1089821634814...|
|668128673560592384|  0.0|Razón tenía aquel...|[razón, tenía, aq...|[-0.0772449535262...|
|668129055238062081|  0.0|349.336 personas ...|[349.336, persona...|[-0.0234433366606...|
|668128820365512704|  0.0|🎉🎉🎉 @ Quilmes,...|[🎉🎉🎉, @, quilm...|[-0.1844782309296...|
|668129436898885633|  0.0|Risottinho de moq...|[risottinho, de, ...|[-0.1389982389893...|
|668128317044826113|  0.0|fish bowl fridays...|[fish, bowl, frid...|[-0.1339836244005...|
|668128518354698242|  0.0|Viendo el partido...|[viendo, el, part...|[-0.1077206712216...|
+------------------+-----+--------------------+--------------------+--------------------+
only showing top 20 rows


In [87]:
# Keep only what the model needs (id, label, embedding): drop the raw text
# and the token list.
kept_columns = [name for name in ml_vec_df.columns if name not in ("text", "words")]
ml_vec_min_df = ml_vec_df.select(*kept_columns)
ml_vec_min_df.show()


+------------------+-----+--------------------+
|                id|label|            wordvecs|
+------------------+-----+--------------------+
|668129332066459648|  0.0|[-0.1050815135240...|
|668129436932415488|  1.0|[-0.2190021015703...|
|668128681945092096|  0.0|[-0.1473689468370...|
|668128627406610432|  0.0|[-0.0281017047153...|
|668128627394019328|  0.0|[-0.0273308289237...|
|668129030068166657|  0.0|[-0.1210067877545...|
|668129332041265152|  0.0|[-0.0882662865691...|
|668128937801682945|  0.0|[-0.0825941441580...|
|668129025890455552|  0.0|[0.13797220580662...|
|668129269160222720|  0.0|[-0.1060454837512...|
|668128677763379201|  0.0|[-0.0877051008865...|
|668128749032902656|  0.0|[-0.1311062552373...|
|668129231386333185|  0.0|[4.58738093988762...|
|668129369819357184|  0.0|[-0.1089821634814...|
|668128673560592384|  0.0|[-0.0772449535262...|
|668129055238062081|  0.0|[-0.0234433366606...|
|668128820365512704|  0.0|[-0.1844782309296...|
|668129436898885633|  0.0|[-0.1389982389893...|
|668128317044826113|  0.0|[-0.1339836244005...|
|668128518354698242|  0.0|[-0.1077206712216...|
+------------------+-----+--------------------+
only showing top 20 rows


In [88]:
# Index the numeric label into "indexedLabel". Fit on the full dataset before
# the split so both label values are present in the index.
label_indexer_spec = StringIndexer(inputCol="label", outputCol="indexedLabel")
labelIndexer = label_indexer_spec.fit(ml_vec_min_df)

In [89]:
# 70/30 train/test split with a fixed seed for reproducibility.
SPLIT_WEIGHTS = [0.7, 0.3]
SPLIT_SEED = 1
training, test = ml_vec_min_df.randomSplit(SPLIT_WEIGHTS, seed=SPLIT_SEED)

In [90]:
# Decision tree over the tweet embeddings, predicting the indexed label.
dt = DecisionTreeClassifier(featuresCol="wordvecs", labelCol="indexedLabel")

In [91]:
# Stage order matters: the label must be indexed before the tree reads it.
pipeline = Pipeline().setStages([labelIndexer, dt])

In [93]:
# Fit every pipeline stage on the training split only.
model = pipeline.fit(dataset=training)

In [94]:
# Score the held-out split.
predictions = model.transform(dataset=test)

In [95]:
# Schema first (shows the columns the pipeline appended), then sample rows.
predictions.printSchema()
predictions.show(n=20)


root
 |-- id: long (nullable = true)
 |-- label: double (nullable = true)
 |-- wordvecs: vector (nullable = true)
 |-- indexedLabel: double (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = true)

+------------------+-----+--------------------+------------+-------------+--------------------+----------+
|                id|label|            wordvecs|indexedLabel|rawPrediction|         probability|prediction|
+------------------+-----+--------------------+------------+-------------+--------------------+----------+
|668129436932415488|  1.0|[-0.2190021015703...|         1.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668128627394019328|  0.0|[-0.0273308289237...|         0.0|   [13.0,8.0]|[0.61904761904761...|       0.0|
|668129332041265152|  0.0|[-0.0882662865691...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668128677763379201|  0.0|[-0.0877051008865...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668128749032902656|  0.0|[-0.1311062552373...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668129055238062081|  0.0|[-0.0234433366606...|         0.0|   [13.0,8.0]|[0.61904761904761...|       0.0|
|668129436898885633|  0.0|[-0.1389982389893...|         0.0|  [137.0,0.0]|           [1.0,0.0]|       0.0|
|667705933833834496|  0.0|[-0.0487950840698...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|667705933825515520|  0.0|[-0.2451408080756...|         0.0|   [77.0,0.0]|           [1.0,0.0]|       0.0|
|667705535400046593|  0.0|[0.19889285196908...|         0.0|   [45.0,9.0]|[0.83333333333333...|       0.0|
|667705069832306688|  0.0|[-0.0239955218774...|         0.0| [101.0,11.0]|[0.90178571428571...|       0.0|
|667705992579256320|  0.0|[-0.0190466323401...|         0.0| [101.0,11.0]|[0.90178571428571...|       0.0|
|667705619252715520|  0.0|[-0.1546504220200...|         0.0|   [77.0,0.0]|           [1.0,0.0]|       0.0|
|668130883946188800|  1.0|[-0.0498293007121...|         1.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668131290781081600|  1.0|[0.04161962299674...|         1.0| [101.0,11.0]|[0.90178571428571...|       0.0|
|668131672475492352|  0.0|[-0.1203543860465...|         0.0| [165.0,13.0]|[0.92696629213483...|       0.0|
|668131018172317697|  0.0|[-0.0839086537200...|         0.0|  [130.0,0.0]|           [1.0,0.0]|       0.0|
|668131919943458817|  0.0|[-0.1943321777507...|         0.0|   [59.0,0.0]|           [1.0,0.0]|       0.0|
|668126962271911936|  0.0|[-0.1602127847986...|         0.0|  [130.0,0.0]|           [1.0,0.0]|       0.0|
|668126962280337408|  1.0|[0.18597440991331...|         1.0|   [49.0,0.0]|           [1.0,0.0]|       0.0|
+------------------+-----+--------------------+------------+-------------+--------------------+----------+
only showing top 20 rows


In [96]:
# Evaluate on the held-out split.
# Accuracy is computed from counts instead of
# MulticlassClassificationEvaluator(metricName="precision"). That metric name
# only means overall accuracy on Spark 1.x and was replaced by "accuracy" in
# Spark 2.x, so counting keeps the number unambiguous across versions.
n_test = predictions.count()
n_correct = predictions.where(predictions.indexedLabel == predictions.prediction).count()
accuracy = n_correct / float(n_test) if n_test else float("nan")

# The label is heavily imbalanced: most tweets have hasMedia == false (see the
# group counts above). A model that always predicts the majority class
# therefore already scores high, so report that baseline next to the accuracy.
# Each collected row is (indexedLabel, count); it is unpacked positionally
# because Row.count is the tuple method.
label_counts = predictions.groupBy("indexedLabel").count().collect()
n_majority = max([n for _, n in label_counts]) if label_counts else 0
baseline = n_majority / float(n_test) if n_test else float("nan")

"Accuracy = {0}, Test Error = {1}, Majority-class baseline = {2}".format(
    accuracy, 1.0 - accuracy, baseline)


Out[96]:
'Accuracy = 0.9426356589147287, Test Error = 0.05736434108527133'

In [97]:
# Total scored rows, then the subset where prediction matches the true label.
total_predictions = predictions.count()
print ("{0}".format(total_predictions))

correct = predictions.filter(predictions["indexedLabel"] == predictions["prediction"])
print ("{0}".format(correct.count()))
correct.show()


645
608
+------------------+-----+--------------------+------------+-------------+--------------------+----------+
|                id|label|            wordvecs|indexedLabel|rawPrediction|         probability|prediction|
+------------------+-----+--------------------+------------+-------------+--------------------+----------+
|668128627394019328|  0.0|[-0.0273308289237...|         0.0|   [13.0,8.0]|[0.61904761904761...|       0.0|
|668129332041265152|  0.0|[-0.0882662865691...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668128677763379201|  0.0|[-0.0877051008865...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668128749032902656|  0.0|[-0.1311062552373...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668129055238062081|  0.0|[-0.0234433366606...|         0.0|   [13.0,8.0]|[0.61904761904761...|       0.0|
|668129436898885633|  0.0|[-0.1389982389893...|         0.0|  [137.0,0.0]|           [1.0,0.0]|       0.0|
|667705933833834496|  0.0|[-0.0487950840698...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|667705933825515520|  0.0|[-0.2451408080756...|         0.0|   [77.0,0.0]|           [1.0,0.0]|       0.0|
|667705535400046593|  0.0|[0.19889285196908...|         0.0|   [45.0,9.0]|[0.83333333333333...|       0.0|
|667705069832306688|  0.0|[-0.0239955218774...|         0.0| [101.0,11.0]|[0.90178571428571...|       0.0|
|667705992579256320|  0.0|[-0.0190466323401...|         0.0| [101.0,11.0]|[0.90178571428571...|       0.0|
|667705619252715520|  0.0|[-0.1546504220200...|         0.0|   [77.0,0.0]|           [1.0,0.0]|       0.0|
|668131672475492352|  0.0|[-0.1203543860465...|         0.0| [165.0,13.0]|[0.92696629213483...|       0.0|
|668131018172317697|  0.0|[-0.0839086537200...|         0.0|  [130.0,0.0]|           [1.0,0.0]|       0.0|
|668131919943458817|  0.0|[-0.1943321777507...|         0.0|   [59.0,0.0]|           [1.0,0.0]|       0.0|
|668126962271911936|  0.0|[-0.1602127847986...|         0.0|  [130.0,0.0]|           [1.0,0.0]|       0.0|
|668125775305031680|  0.0|[0.06825192621909...|         0.0|  [300.0,8.0]|[0.97402597402597...|       0.0|
|668127545284497408|  0.0|[-0.0344093372114...|         0.0|   [13.0,8.0]|[0.61904761904761...|       0.0|
|668127339750928384|  0.0|[0.14869336241527...|         0.0| [101.0,11.0]|[0.90178571428571...|       0.0|
|668379908150837248|  0.0|[-0.0976984456858...|         0.0|  [130.0,0.0]|           [1.0,0.0]|       0.0|
+------------------+-----+--------------------+------------+-------------+--------------------+----------+
only showing top 20 rows


In [98]:
# The fitted pipeline's last stage (after the label indexer) is the tree.
treeModel = model.stages[-1]
print(treeModel)  # summary only


DecisionTreeClassificationModel of depth 5 with 49 nodes