In [17]:
from pyspark.ml.linalg import Vectors, SparseVector, DenseVector
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor
from collections import defaultdict
from pyspark import SparkContext
from pyspark.sql import SparkSession
import os
import json
import cPickle as pickle
import math
from pyspark.sql import functions as F


def tuple2sparse(tp, size=43, begin=19, end=42):
    dic = {}
    for i in xrange(end-begin):
        if (tp[i] - 0) > 10e-4:
            dic[i+begin] = tp[i]
    v = Vectors.sparse(size, dic)
    return v


def add(v1, v2):
    assert isinstance(v1, SparseVector) and isinstance(v2, SparseVector), 'One of them is not SparseVector!'
    assert v1.size == v2.size, 'Size not equal!'
    values = defaultdict(float) # Dictionary with default value 0.0
    # Add values from v1
    for i in range(v1.indices.size):
        values[v1.indices[i]] += v1.values[i]
    # Add values from v2
    for i in range(v2.indices.size):
        values[v2.indices[i]] += v2.values[i]
    return Vectors.sparse(v1.size, dict(values))

In [31]:
def loadDataJson(business_path='', user_path='', star_path=''):
    bDF = spark.read.json(business_path)
    uDF = spark.read.json(user_path)
    sDF = spark.read.json(star_path)

    businessDF = bDF.rdd.map(lambda x: (x['b_id'], tuple2sparse(
                         tuple(x['loc']) + tuple(x['votes']) + (x['avg_star'], ) +
                         tuple(x['cates']) + (x['rev_num'], ) + tuple(x['ckins']),
                         begin=19, end=42))).toDF(['b_id', 'b_features'])

    userDF = uDF.rdd.map(lambda x: (x['u_id'], tuple2sparse(
                     tuple(x['loc']) + tuple(x['votes']) +
                     (x['loc_num'], x['avg_star'], x['rev_num']) + tuple(x['cates']),
                     begin=0, end=19))).toDF(['u_id', 'u_features'])

    starDF = sDF.select((sDF.business_id).alias('b_id'), (sDF.user_id).alias('u_id'), 
                        (sDF.stars).alias('label'), (sDF.review_id).alias('rev_id'))
    return businessDF, userDF, starDF


def transData4GBT(businessDF, userDF, starDF):
    alldata = starDF.select(starDF.b_id, starDF.u_id, starDF.label) \
                    .join(businessDF, starDF.b_id == businessDF.b_id).drop(businessDF.b_id) \
                    .join(userDF, starDF.u_id == userDF.u_id).drop(userDF.u_id)\
                    .select('label', 'b_features', 'u_features', 'u_id', 'b_id')
    assembler = VectorAssembler(
                    inputCols=["b_features", "u_features"],
                    outputCol="features")

    data = assembler.transform(alldata).drop('b_features', 'u_features')
    return data


def traingbt(datafrom='json', business_path='', user_path='', star_path=''):
    gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
    if datafrom == 'json':
        businessDF, userDF, starDF = loadDataJson(business_path=business_path,
                                                  user_path=user_path,
                                                  star_path=star_path)
    elif datafrom == 'mongodb':
        businessDF, userDF, starDF = loadDataMongo()
    data = transData4GBT(businessDF, userDF, starDF)
    model = gbt.fit(data)
    return model

In [19]:
def recommendation(businessDF, userDF, testDF, model):
    CartesianDF = testDF.crossJoin(businessDF.select('b_id')).drop(testStarDF.b_id).drop('rev_id')
    recDF = transData4GBT(businessDF, userDF, CartesianDF)
    predDF = model.transform(recDF)
    
    temp = predDF.groupby('u_id').agg(F.max(predDF.prediction)) \
                 .withColumnRenamed('max(prediction)', 'prediction')
    pred = temp.join(predDF, ['prediction', 'u_id'], 'outer').drop(predDF.u_id).drop(predDF.prediction)
    pred = pred.select('u_id', 'b_id')

    return pred

In [34]:
business_path = 'businesses.json'
user_path = 'users.json'
star_path = 'yelp_academic_dataset_review.json'


gbt = GBTRegressor(maxIter=50, maxDepth=6, seed=42)
businessDF, userDF, starDF = loadDataJson(business_path=business_path,
                                          user_path=user_path,
                                          star_path=star_path)
# split starDF to training data and test data
trainStarDF, testStarDF = starDF.randomSplit([0.7, 0.3])

trainDF = transData4GBT(businessDF, userDF, trainStarDF)

model = gbt.fit(trainDF)

testDF = transData4GBT(businessDF, userDF, testStarDF)
predDF = model.transform(testDF)

predDF.show()
errors = predDF.rdd.map(lambda x: (x.label - x.prediction)**2).collect()
RMSE = math.sqrt(sum(errors)/len(errors))
print 'RMSE: %.8f' % RMSE

# recDF = recommendation(businessDF, testStarDF, model)
# recDF.printSchema()
# recDF.show()


+-----+--------------------+--------------------+--------------------+------------------+
|label|                u_id|                b_id|            features|        prediction|
+-----+--------------------+--------------------+--------------------+------------------+
|    5|-1zQA2f_syMAdA04P...|aNe8ofTYrealxqv7V...|(86,[19,21,22,23,...| 2.624166121228215|
|    2|-3i9bhfvrM3F1wsC9...|5iSmZO0SrKU6EoXK_...|(86,[19,21,22,23,...|  3.63287926235794|
|    4|-3i9bhfvrM3F1wsC9...|ghpFh6XpH1TYZhjAG...|(86,[19,21,22,23,...| 4.274533711932253|
|    4|-3i9bhfvrM3F1wsC9...|GtHu9uGXpn7Jg_Z7v...|(86,[19,21,22,23,...|3.8912370468409585|
|    4|-4Anvj46CWf57KWI9...|vKA9sIqBcW0UlTKGh...|(86,[19,21,22,23,...|3.3628979446389167|
|    4|-55DgUo52I3zW9Rxk...|3qlqzQrwh8hjBltlg...|(86,[19,21,22,23,...|  4.57561898326436|
|    5|-55DgUo52I3zW9Rxk...|CYWRPE-1IHPBb-zfF...|(86,[19,21,22,23,...|4.5735588804823175|
|    5|-55DgUo52I3zW9Rxk...|3awTUGMdUVrwEBkFF...|(86,[19,21,22,23,...| 4.580856617019137|
|    4|-55DgUo52I3zW9Rxk...|6t98--hqg8suYkm_3...|(86,[19,21,22,23,...| 4.322059431814168|
|    5|-7JSlmBJKUQwREG_y...|QD0gnPAdy7w2vZZG9...|(86,[19,21,22,23,...| 4.602464715419953|
|    1|-7V6r0PLuBlFVjbLJ...|x8O-Mll5ksDpeIgtA...|(86,[19,21,22,23,...|1.6449891781854125|
|    4|-9da1xk7zgnnfO1uT...|PXShA3JZMXr2mEH3o...|(86,[19,21,22,23,...| 4.517274403933761|
|    4|-9da1xk7zgnnfO1uT...|_YUcCnJXjUgkS9fSn...|(86,[19,21,22,23,...| 4.259394274403336|
|    5|-9da1xk7zgnnfO1uT...|7dHYudt6OOIjiaxkS...|(86,[19,21,22,23,...|  4.17251951814607|
|    4|-9da1xk7zgnnfO1uT...|EUWBT5GDxPC95w9it...|(86,[19,21,22,23,...| 3.863477352529793|
|    2|-9da1xk7zgnnfO1uT...|lYCeqldIiOggsbByH...|(86,[19,21,22,23,...|2.6200234809776504|
|    3|-9da1xk7zgnnfO1uT...|q18xbq3Cbyp_BJyfM...|(86,[19,21,22,23,...|  3.96319528777114|
|    5|-9da1xk7zgnnfO1uT...|y0x795PyDX8JL_oyI...|(86,[19,21,22,23,...| 3.863477352529793|
|    5|-9da1xk7zgnnfO1uT...|yofHPSC24EsWTMJq3...|(86,[19,21,22,23,...|3.6645203093244483|
|    4|-9da1xk7zgnnfO1uT...|jaJnPIX9VxsFyfV5z...|(86,[19,21,22,23,...| 4.266858412775395|
+-----+--------------------+--------------------+--------------------+------------------+
only showing top 20 rows

RMSE: 1.05118689

In [32]:
# recDF = recommendation(businessDF, userDF, testStarDF, model)
# recDF.printSchema()
# recDF.show()

CartesianDF = testStarDF.crossJoin(businessDF.select('b_id')).drop(testStarDF.b_id).drop('rev_id')
CartesianDF.printSchema()


root
 |-- u_id: string (nullable = true)
 |-- label: long (nullable = true)
 |-- b_id: string (nullable = true)

In [33]:
recDF = transData4GBT(businessDF, userDF, CartesianDF)
# predDF = model.transform(recDF)
    
# temp = predDF.groupby('u_id').agg(F.max(predDF.prediction)) \
#                  .withColumnRenamed('max(prediction)', 'prediction')
# pred = temp.join(predDF, ['prediction', 'u_id'], 'outer').drop(predDF.u_id).drop(predDF.prediction)
# pred = pred.select('u_id', 'b_id')

recDF.printSchema()


An error occurred while calling o1090.transform.
: org.apache.spark.SparkException: Job 46 cancelled because Stage 214 was cancelled
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
	at org.apache.spark.scheduler.DAGScheduler.handleJobCancellation(DAGScheduler.scala:1375)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply$mcVI$sp(DAGScheduler.scala:1364)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1363)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1363)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofInt.foreach(ArrayOps.scala:234)
	at org.apache.spark.scheduler.DAGScheduler.handleStageCancellation(DAGScheduler.scala:1363)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1619)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1925)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1938)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1951)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:333)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2378)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
	at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2780)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2377)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2384)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2120)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2119)
	at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2810)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2119)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2128)
	at org.apache.spark.sql.Dataset.first(Dataset.scala:2135)
	at org.apache.spark.ml.feature.VectorAssembler.first$lzycompute$1(VectorAssembler.scala:57)
	at org.apache.spark.ml.feature.VectorAssembler.org$apache$spark$ml$feature$VectorAssembler$$first$1(VectorAssembler.scala:57)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2$$anonfun$1.apply$mcI$sp(VectorAssembler.scala:88)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2$$anonfun$1.apply(VectorAssembler.scala:88)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2$$anonfun$1.apply(VectorAssembler.scala:88)
	at scala.Option.getOrElse(Option.scala:121)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2.apply(VectorAssembler.scala:88)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2.apply(VectorAssembler.scala:58)
	at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
	at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
	at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:186)
	at org.apache.spark.ml.feature.VectorAssembler.transform(VectorAssembler.scala:58)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:745)

Traceback (most recent call last):
  File "<stdin>", line 30, in transData4GBT
  File "/usr/hdp/current/spark2-client/python/pyspark/ml/base.py", line 105, in transform
    return self._transform(dataset)
  File "/usr/hdp/current/spark2-client/python/pyspark/ml/wrapper.py", line 252, in _transform
    return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
  File "/usr/hdp/current/spark2-client/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/hdp/current/spark2-client/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/usr/hdp/current/spark2-client/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 319, in get_return_value
    format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling o1090.transform.
: org.apache.spark.SparkException: Job 46 cancelled because Stage 214 was cancelled
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
	at org.apache.spark.scheduler.DAGScheduler.handleJobCancellation(DAGScheduler.scala:1375)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply$mcVI$sp(DAGScheduler.scala:1364)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1363)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1363)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofInt.foreach(ArrayOps.scala:234)
	at org.apache.spark.scheduler.DAGScheduler.handleStageCancellation(DAGScheduler.scala:1363)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1619)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1925)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1938)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1951)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:333)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2378)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
	at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2780)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2377)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2384)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2120)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2119)
	at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2810)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2119)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2128)
	at org.apache.spark.sql.Dataset.first(Dataset.scala:2135)
	at org.apache.spark.ml.feature.VectorAssembler.first$lzycompute$1(VectorAssembler.scala:57)
	at org.apache.spark.ml.feature.VectorAssembler.org$apache$spark$ml$feature$VectorAssembler$$first$1(VectorAssembler.scala:57)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2$$anonfun$1.apply$mcI$sp(VectorAssembler.scala:88)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2$$anonfun$1.apply(VectorAssembler.scala:88)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2$$anonfun$1.apply(VectorAssembler.scala:88)
	at scala.Option.getOrElse(Option.scala:121)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2.apply(VectorAssembler.scala:88)
	at org.apache.spark.ml.feature.VectorAssembler$$anonfun$2.apply(VectorAssembler.scala:58)
	at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
	at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
	at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:186)
	at org.apache.spark.ml.feature.VectorAssembler.transform(VectorAssembler.scala:58)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:745)



In [30]:
predDF = model.transform(recDF)
predDF.printSchema()


root
 |-- label: long (nullable = true)
 |-- b_features: vector (nullable = true)
 |-- u_features: vector (nullable = true)
 |-- u_id: string (nullable = true)
 |-- b_id: string (nullable = true)
 |-- features: vector (nullable = true)
 |-- prediction: double (nullable = true)

In [39]:
from pyspark.sql import functions as F
temp = prediction.groupby('u_id').agg(F.max(prediction.prediction)) \
                 .withColumnRenamed('max(prediction)', 'prediction')
pred = temp.join(prediction, ['prediction', 'u_id'], 'outer').drop(prediction.u_id).drop(prediction.prediction)

pred.printSchema()
pred.show()


root
 |-- prediction: double (nullable = true)
 |-- u_id: string (nullable = true)
 |-- label: double (nullable = true)
 |-- b_id: string (nullable = true)
 |-- features: vector (nullable = true)

+------------------+--------------------+-----+--------------------+--------------------+
|        prediction|                u_id|label|                b_id|            features|
+------------------+--------------------+-----+--------------------+--------------------+
|1.5334101408098666|09D5rRgsFmHRJKVBm...|  2.0|EtLmuDIMsBCmqdgpy...|(43,[0,5,6,7,8,11...|
|1.5334101408098666|1-VO40rPQDC4Q9Spt...|  1.0|K0pN6x7fzmsO8h36N...|(43,[0,5,6,7,9,10...|
|1.5334101408098666|2ezaoRp1PzHaMgIrz...|  1.0|sPd3E7lFzd_yooiq-...|(43,[0,5,6,7,8,18...|
|1.5334101408098666|4-ElUwzF5CgbEy0ay...|  1.0|FR7dh1_TnWNyGbhhq...|(43,[0,5,6,7,12,1...|
|1.5334101408098666|4c-dhmNntBrpUHOCc...|  1.0|vGLl5xum2u2Qf8_Av...|(43,[0,5,6,7,8,9,...|
|1.5334101408098666|4c-dhmNntBrpUHOCc...|  1.0|MTtI2bqoNHN_0m2cH...|(43,[0,5,6,7,8,9,...|
|1.5334101408098666|5Acw0JaxH1A_hDC3S...|  1.0|LIPtg0tDFCNTt-fqb...|(43,[0,1,5,6,7,8,...|
|1.5334101408098666|5qKemxZi2wnA0NzKr...|  1.0|_HVZ1V8IDa49MWdej...|(43,[0,5,6,7,8,18...|
|1.5334101408098666|6Uc4bSDRcrwA-kx1W...|  1.0|WfroD4iB5M1nFw8j8...|(43,[0,5,6,7,12,1...|
|1.5334101408098666|6qwKduiMppfvjJNZB...|  1.0|XG8dARktPWMFpBiQG...|(43,[0,5,6,7,8,10...|
|1.5334101408098666|74UIfuojXgxw3df1C...|  1.0|VnS1QKpsfGj_7c058...|(43,[0,5,6,7,12,1...|
|1.5334101408098666|85pNTdsC6DWqfifX4...|  1.0|2uLU7C6-59QKdiTaw...|(43,[0,2,5,6,7,19...|
|1.5334101408098666|89Fmu93aliAeF2-6U...|  1.0|_V0yJdpXrbdKzBDoV...|(43,[0,5,6,7,8,10...|
|1.5334101408098666|AVJYfsEnp-pXFynk5...|  1.0|YfDmf2hBB8jEdKIEX...|(43,[0,2,5,6,7,9,...|
|1.5334101408098666|AxqbtVrhqubobl5OM...|  1.0|II-vMV6s9Ke6l9V7j...|(43,[0,5,6,7,9,11...|
|1.5334101408098666|C7F_PCbwjx3yIaqXp...|  1.0|vvvDtPXzZHnAYxECh...|(43,[0,5,6,7,16,1...|
|1.5334101408098666|CRYbqNcA31OautCr2...|  1.0|_THIu8AX6CyBmP_3p...|(43,[0,3,5,6,7,8,...|
|1.5334101408098666|ColBn9YdAVZ0HYSpH...|  2.0|HuzhmzDHcI66G1744...|(43,[0,2,3,4,5,6,...|
|1.5334101408098666|F1CP23wUsStv5ObB-...|  1.0|Zo0DWTyHTSyKRVIWI...|(43,[0,5,6,7,8,10...|
|1.5334101408098666|HpxYfwGSLI2uiGZ1q...|  1.0|fkfVkLnoPNgVddCy0...|(43,[0,5,6,7,8,10...|
+------------------+--------------------+-----+--------------------+--------------------+
only showing top 20 rows

In [34]:
businessDF.printSchema()
userDF.printSchema()
starDF.printSchema()
trainStarDF.printSchema()
prediction.printSchema()


root
 |-- b_id: string (nullable = true)
 |-- b_features: vector (nullable = true)

root
 |-- u_id: string (nullable = true)
 |-- u_features: vector (nullable = true)

root
 |-- b_id: string (nullable = true)
 |-- u_id: string (nullable = true)
 |-- label: long (nullable = true)
 |-- rev_id: string (nullable = true)

root
 |-- b_id: string (nullable = true)
 |-- u_id: string (nullable = true)
 |-- label: long (nullable = true)
 |-- rev_id: string (nullable = true)

root
 |-- label: double (nullable = true)
 |-- u_id: string (nullable = true)
 |-- b_id: string (nullable = true)
 |-- features: vector (nullable = true)
 |-- prediction: double (nullable = true)

In [50]:
CartesianDF = testStarDF.crossJoin(businessDF.select('b_id')).drop(testStarDF.b_id).drop('rev_id')
CartesianDF.show()


+--------------------+-----+--------------------+
|                u_id|label|                b_id|
+--------------------+-----+--------------------+
|dVy5EAV9YZIl9Xl-X...|    3|QgNYM-ccNhJ8eGsQP...|
|dVy5EAV9YZIl9Xl-X...|    3|BqsIt1BQKzS-hEKLY...|
|dVy5EAV9YZIl9Xl-X...|    3|MH0oOCJ7DKnIJWwUQ...|
|dVy5EAV9YZIl9Xl-X...|    3|s9ZY6ESOJF0mABkGr...|
|dVy5EAV9YZIl9Xl-X...|    3|ln8nvcRttTQTZeDjc...|
|dVy5EAV9YZIl9Xl-X...|    3|f8e1MH4YvIY1Km7W2...|
|dVy5EAV9YZIl9Xl-X...|    3|llifBVCFAnr124WdK...|
|dVy5EAV9YZIl9Xl-X...|    3|FiB1rfmgaED4mmHpO...|
|dVy5EAV9YZIl9Xl-X...|    3|lZaBsXK-vhxL1Ck8E...|
|dVy5EAV9YZIl9Xl-X...|    3|krBpN5vbCQrB54QvT...|
|dVy5EAV9YZIl9Xl-X...|    3|xxjxUM-VK4N33LN8N...|
|dVy5EAV9YZIl9Xl-X...|    3|qkjOhzGvUPSdKX7ss...|
|dVy5EAV9YZIl9Xl-X...|    3|2G_6PBM-klbh1u2v2...|
|dVy5EAV9YZIl9Xl-X...|    3|GD9mTnCht2bog2yb0...|
|dVy5EAV9YZIl9Xl-X...|    3|E0T-xQJXpM6Hsm-Ee...|
|dVy5EAV9YZIl9Xl-X...|    3|sPvjzXjzvGFwBLwtn...|
|dVy5EAV9YZIl9Xl-X...|    3|8u6NUtxSPH3CbLKTQ...|
|dVy5EAV9YZIl9Xl-X...|    3|qJRZ7eaarbermHS3Z...|
|dVy5EAV9YZIl9Xl-X...|    3|rslX_CGOBr5m6yOor...|
|dVy5EAV9YZIl9Xl-X...|    3|dQ11s7taakRn8omBU...|
+--------------------+-----+--------------------+
only showing top 20 rows

In [17]:


In [ ]:


In [ ]: