In [1]:
from recommender import Recommender
from pyspark.sql import functions as F
import numpy as np

In [2]:
# Load the restaurant reviews.
reviews_df = spark.read.parquet('../data/ratings_ugt10_igt10')

# Randomly split the data into train (75%) and test (25%) datasets.
# A fixed seed makes the split repeatable for the same input data, so the
# numbers shown further down can be reproduced after a kernel restart.
SPLIT_SEED = 42
train_df, test_df = reviews_df.randomSplit(weights=[0.75, 0.25], seed=SPLIT_SEED)

# printSchema() writes the schema itself and returns None; wrapping it in
# print() only appends a stray "None" to the cell output.
train_df.printSchema()


root
 |-- user: integer (nullable = true)
 |-- item: integer (nullable = true)
 |-- rating: byte (nullable = true)

None

In [3]:
# Configure the project-local Recommender estimator. The column names match
# the schema of the reviews data (user / item / rating).
estimator = Recommender(
    useALS=True,
    useBias=True,
    lambda_1=7,
    lambda_2=12,
    userCol='user',
    itemCol='item',
    ratingCol='rating',
    rank=76,
    regParam=0.7,
    maxIter=15,
    nonnegative=False
)

# Fit on the training split only; the test split is kept for validation.
model = estimator.fit(train_df)

# transform() appends a `prediction` column to the rows it is given.
train_predictions_df = model.transform(train_df)
test_predictions_df = model.transform(test_df)

# printSchema() prints the schema itself and returns None, so it is called
# directly instead of being wrapped in print().
test_predictions_df.printSchema()


Fit done in 69.23029148101341 seconds
Transform done in 0.2065514309797436 seconds
Transform done in 0.26547738403314725 seconds
root
 |-- user: integer (nullable = true)
 |-- item: integer (nullable = true)
 |-- rating: byte (nullable = true)
 |-- prediction: double (nullable = true)

None

In [4]:
# Expose the prediction DataFrames to Spark SQL under their variable names.
# createOrReplaceTempView supersedes registerTempTable (deprecated since
# Spark 2.0) and is safe to re-run because it replaces an existing view.
train_predictions_df.createOrReplaceTempView('train_predictions_df')
test_predictions_df.createOrReplaceTempView('test_predictions_df')


def rank_user_predictions(table_name, user):
    """Return one user's rows ranked by predicted and by actual rating.

    Args:
        table_name: Name of a registered temp view that has the columns
            user, item, rating and prediction.
        user: Integer id of the user whose rows are selected.

    Returns:
        A DataFrame with the original columns plus `pred_row_num` (rank by
        prediction, descending) and `actual_row_num` (rank by rating,
        descending), ordered by `pred_row_num`.
    """
    return spark.sql(
        '''
        select
            user,
            item,
            rating,
            prediction,
            row_number() over (
                partition by user
                order by prediction desc
            ) as pred_row_num,
            row_number() over (
                partition by user
                order by rating desc
            ) as actual_row_num
        from {table}
        where user = {user}
        order by pred_row_num
        '''.format(table=table_name, user=int(user))  # int() keeps the literal numeric
    )


# Same query against the train and the test predictions.
df1 = rank_user_predictions('train_predictions_df', 3000)
df2 = rank_user_predictions('test_predictions_df', 3000)

# show() prints the table and returns None, so it is not wrapped in print().
df1.show(100)
df2.show(100)


+----+----+------+------------------+------------+--------------+
|user|item|rating|        prediction|pred_row_num|actual_row_num|
+----+----+------+------------------+------------+--------------+
|3000| 480|     4| 4.156959697988615|           1|            14|
|3000|1408|     5|3.9188534476927734|           2|             4|
|3000|1159|     5| 3.891992777999069|           3|             3|
|3000| 460|     3|3.8918789298133114|           4|            18|
|3000|  84|     4|3.8664549280181117|           5|            12|
|3000|1277|     4|3.8565775172204755|           6|            16|
|3000| 155|     5| 3.849672341016376|           7|             1|
|3000| 358|     4|3.8364246954375174|           8|            10|
|3000| 474|     4|3.8229278745707402|           9|             8|
|3000| 445|     4|3.7888557441245867|          10|            11|
|3000| 261|     3| 3.732590468515996|          11|            20|
|3000|  83|     4| 3.715063090416448|          12|            15|
|3000| 361|     4|3.6976815655068678|          13|            13|
|3000|2740|     2|3.6934577708616523|          14|            24|
|3000| 920|     4|3.6668854801646726|          15|             9|
|3000|1300|     1|3.6326871941913663|          16|            25|
|3000|1138|     3|3.6309059924756384|          17|            23|
|3000|1606|     3|3.6112455254114626|          18|            22|
|3000|1779|     3| 3.608781551857008|          19|            19|
|3000| 930|     1|3.6044224141105414|          20|            28|
|3000| 115|     3|3.5671699145522853|          21|            17|
|3000| 712|     3| 3.492542564480853|          22|            21|
|3000|1399|     4|3.4864800113879246|          23|             7|
|3000| 546|     5| 3.432808835668477|          24|             5|
|3000|1568|     5| 3.422309281434936|          25|             6|
|3000| 192|     5|3.3236665570569057|          26|             2|
|3000|1002|     1| 3.137779689127302|          27|            26|
|3000| 819|     1| 2.813132299544523|          28|            27|
+----+----+------+------------------+------------+--------------+

None
+----+----+------+------------------+------------+--------------+
|user|item|rating|        prediction|pred_row_num|actual_row_num|
+----+----+------+------------------+------------+--------------+
|3000|  11|     4| 4.063751880822793|           1|             3|
|3000| 108|     4|3.9815876055915678|           2|             4|
|3000| 565|     1|3.9698843601596123|           3|            19|
|3000|  81|     3|3.9410310393464494|           4|             9|
|3000|1173|     3| 3.912003252448265|           5|            10|
|3000|1755|     5|3.7595302246612543|           6|             1|
|3000|1207|     3| 3.741954547770718|           7|            11|
|3000|  90|     5|3.7310606722129744|           8|             2|
|3000|2294|     2| 3.714858779685562|           9|            16|
|3000| 566|     4| 3.654530590807677|          10|             5|
|3000|1019|     4|3.6023476622435844|          11|             6|
|3000| 303|     3| 3.590431380930554|          12|            12|
|3000|1269|     4|3.5849896959285315|          13|             7|
|3000|  16|     4| 3.553801078816406|          14|             8|
|3000| 275|     2|3.5476632238023145|          15|            17|
|3000| 740|     3| 3.458659232114541|          16|            13|
|3000| 297|     3| 3.399766468988462|          17|            14|
|3000| 952|     2| 3.341331525043009|          18|            18|
|3000| 123|     3|3.1481461771234347|          19|            15|
+----+----+------+------------------+------------+--------------+

None

In [6]:
# The user that is treated as a "new" user in the cells below.
user_id = 3000


def select_user_predictions(table_name, user):
    """Select one user's ratings and the model's original predictions.

    Args:
        table_name: Name of a registered temp view that has the columns
            user, item, rating and prediction.
        user: Integer id of the user whose rows are selected.

    Returns:
        A DataFrame with the columns user, item, rating and
        `orig_prediction` (the view's `prediction` column, renamed).
    """
    return spark.sql(
        '''
        select
            user,
            item,
            rating,
            prediction as orig_prediction
        from {table}
        where user = {user}
        '''.format(table=table_name, user=int(user))  # int() keeps the literal numeric
    )


# Rows from the training predictions and from the held-out test predictions.
new_user_df = select_user_predictions('train_predictions_df', user_id)
new_user_validate_df = select_user_predictions('test_predictions_df', user_id)

# show() prints the table and returns None, so it is not wrapped in print().
new_user_df.show(100)
new_user_validate_df.show(100)


+----+----+------+------------------+
|user|item|rating|   orig_prediction|
+----+----+------+------------------+
|3000| 155|     5| 3.849672341016376|
|3000| 115|     3|3.5671699145522853|
|3000| 192|     5|3.3236665570569057|
|3000| 460|     3|3.8918789298133114|
|3000|1300|     1|3.6326871941913663|
|3000|1399|     4|3.4864800113879246|
|3000| 474|     4|3.8229278745707402|
|3000| 920|     4|3.6668854801646726|
|3000|2740|     2|3.6934577708616523|
|3000|1002|     1| 3.137779689127302|
|3000| 358|     4|3.8364246954375174|
|3000| 819|     1| 2.813132299544523|
|3000|1159|     5| 3.891992777999069|
|3000| 445|     4|3.7888557441245867|
|3000|1779|     3| 3.608781551857008|
|3000| 930|     1|3.6044224141105414|
|3000|  84|     4|3.8664549280181117|
|3000| 361|     4|3.6976815655068678|
|3000| 261|     3| 3.732590468515996|
|3000| 480|     4| 4.156959697988615|
|3000| 712|     3| 3.492542564480853|
|3000|1408|     5|3.9188534476927734|
|3000|  83|     4| 3.715063090416448|
|3000| 546|     5| 3.432808835668477|
|3000|1606|     3|3.6112455254114626|
|3000|1568|     5| 3.422309281434936|
|3000|1277|     4|3.8565775172204755|
|3000|1138|     3|3.6309059924756384|
+----+----+------+------------------+

None
+----+----+------+------------------+
|user|item|rating|   orig_prediction|
+----+----+------+------------------+
|3000| 108|     4|3.9815876055915678|
|3000|2294|     2| 3.714858779685562|
|3000|  81|     3|3.9410310393464494|
|3000|1019|     4|3.6023476622435844|
|3000|1269|     4|3.5849896959285315|
|3000|1207|     3| 3.741954547770718|
|3000| 297|     3| 3.399766468988462|
|3000|  16|     4| 3.553801078816406|
|3000| 565|     1|3.9698843601596123|
|3000|1755|     5|3.7595302246612543|
|3000| 952|     2| 3.341331525043009|
|3000| 740|     3| 3.458659232114541|
|3000|1173|     3| 3.912003252448265|
|3000|  90|     5|3.7310606722129744|
|3000| 303|     3| 3.590431380930554|
|3000|  11|     4| 4.063751880822793|
|3000| 123|     3|3.1481461771234347|
|3000| 566|     4| 3.654530590807677|
|3000| 275|     2|3.5476632238023145|
+----+----+------+------------------+

None

In [7]:
# Pull out the item H matrix
item_factors_df = model.itemFactors

# Latent factor vector the model learned for the selected user.
user_factors_df = model.userFactors.filter(F.col('id') == user_id)
user_factors_row = user_factors_df.first()
if user_factors_row is None:
    # Fail with a clear message instead of an IndexError on an empty result.
    raise ValueError('No user factors found for user {}'.format(user_id))
user_factors = np.array(user_factors_row['features'])
print(len(user_factors))
print(user_factors)

# Keep only the factors of the items this user rated in the training
# predictions; the user's rating columns are carried along by the join.
filtered_item_factors_df = item_factors_df.join(new_user_df, F.col('id') == new_user_df['item'])

# show() prints the table and returns None, so it is not wrapped in print().
filtered_item_factors_df.show(100)


76
[ -1.20576360e-07  -8.76581154e-08  -1.51162027e-07   8.71703321e-07
   3.21286763e-07  -1.28828788e-07   1.39126257e-06  -3.77913238e-07
  -1.86078751e-07  -3.04741349e-07   1.25926715e-06  -6.70592840e-07
  -6.93869708e-07  -1.22939127e-06  -1.31133660e-07   2.85535236e-07
   1.78593507e-06  -7.31320483e-07  -5.98695806e-07   4.97694259e-07
  -1.38382450e-06  -2.93457258e-08   6.43191470e-07  -1.55329599e-06
  -2.27806254e-06   1.41107012e-06   2.47590356e-06   2.75408775e-06
   5.62041180e-07   3.36179227e-07  -2.85554425e-08   1.52786797e-07
   2.53999502e-07   3.30409762e-06   2.90849073e-07   3.56460475e-07
   5.93028858e-07   5.03058629e-07  -7.70375550e-07   9.61990281e-07
  -2.66904266e-07  -1.69946088e-06  -9.13650638e-07   2.11928423e-07
   8.26959990e-07   1.06914842e-06   9.22470008e-07  -2.17247589e-06
   1.42489625e-06   6.28112161e-07  -7.53763274e-08   7.00499015e-07
   5.99065800e-07   4.28337074e-07  -1.07303583e-06  -2.10397548e-06
  -2.36752771e-06   8.93263774e-08  -1.93525148e-06  -1.85716237e-06
  -6.72233682e-07  -4.68524064e-07  -2.66593105e-07  -9.78253752e-07
  -3.81688920e-07  -3.02801681e-07  -1.33146978e-06  -1.66978873e-06
   2.32492766e-07   3.60807888e-07   2.69814251e-07  -1.08216136e-06
  -5.94597850e-07  -1.98544171e-06  -4.13569467e-07   1.67778339e-07]
+----+--------------------+----+----+------+------------------+
|  id|            features|user|item|rating|   orig_prediction|
+----+--------------------+----+----+------+------------------+
| 155|[6.7439736E-8, -4...|3000| 155|     5| 3.849672341016376|
| 115|[-1.6427094E-6, -...|3000| 115|     3|3.5671699145522853|
| 192|[-1.009878E-6, 2....|3000| 192|     5|3.3236665570569057|
| 460|[2.0108375E-7, 9....|3000| 460|     3|3.8918789298133114|
|1300|[1.3386878E-7, 3....|3000|1300|     1|3.6326871941913663|
|1399|[2.5853936E-7, -2...|3000|1399|     4|3.4864800113879246|
| 474|[-1.7339776E-6, 4...|3000| 474|     4|3.8229278745707402|
| 920|[1.9017942E-7, -7...|3000| 920|     4|3.6668854801646726|
|2740|[6.1225485E-8, -9...|3000|2740|     2|3.6934577708616523|
|1002|[3.1157538E-7, 2....|3000|1002|     1| 3.137779689127302|
| 358|[3.6583282E-7, -7...|3000| 358|     4|3.8364246954375174|
| 819|[7.424917E-9, 1.2...|3000| 819|     1| 2.813132299544523|
|1159|[-1.6000612E-7, -...|3000|1159|     5| 3.891992777999069|
| 445|[9.284124E-7, 3.5...|3000| 445|     4|3.7888557441245867|
|1779|[1.8435967E-8, 3....|3000|1779|     3| 3.608781551857008|
| 930|[4.935266E-7, -3....|3000| 930|     1|3.6044224141105414|
|  84|[8.6385405E-8, 5....|3000|  84|     4|3.8664549280181117|
| 361|[1.6421822E-7, 5....|3000| 361|     4|3.6976815655068678|
| 261|[2.711676E-7, -2....|3000| 261|     3| 3.732590468515996|
| 480|[-2.2908354E-8, -...|3000| 480|     4| 4.156959697988615|
| 712|[-1.447679E-7, 5....|3000| 712|     3| 3.492542564480853|
|1408|[2.079922E-7, -1....|3000|1408|     5|3.9188534476927734|
|  83|[3.8330705E-7, -1...|3000|  83|     4| 3.715063090416448|
| 546|[-4.1839897E-7, 6...|3000| 546|     5| 3.432808835668477|
|1606|[8.5061515E-8, -9...|3000|1606|     3|3.6112455254114626|
|1568|[6.81889E-7, 3.34...|3000|1568|     5| 3.422309281434936|
|1277|[-7.1591575E-7, 9...|3000|1277|     4|3.8565775172204755|
|1138|[-6.772744E-7, -8...|3000|1138|     3|3.6309059924756384|
+----+--------------------+----+----+------+------------------+

None

In [14]:
# Statistics the fitted model exposes for the bias terms.
rating_stats_df = model.rating_stats_df
item_bias_df = model.item_bias_df

# Residual rating: what is left of the raw rating once the average rating and
# the item's bias have been subtracted.
residual_rating = F.col('rating') - F.col('avg_rating') - F.col('item_bias')

# Attach the rating statistics to every row, then the per-item bias columns.
with_stats_df = filtered_item_factors_df.crossJoin(rating_stats_df)
with_bias_df = with_stats_df.join(item_bias_df, on='item')

# Preserve the raw rating first, then overwrite `rating` with the residual.
# The order matters: swapping the two steps would copy the residual instead.
filtered_item_factors_df2 = (
    with_bias_df
    .withColumn('orig_rating', F.col('rating'))
    .withColumn('rating', residual_rating)
)

# Columns to inspect ('stddev_rating' is intentionally left out).
display_columns = [
    'item', 'user', 'rating', 'orig_prediction',
    'avg_rating',
    'item_bias', 'avg_diffs_item_rating',
    'stderr_diffs_item_rating', 'stddev_diffs_item_rating',
    'count_item_rating', 'orig_rating',
]
filtered_item_factors_df2.select(*display_columns).show(100, truncate=False)


+----+----+-------------------+------------------+------------------+----------------------+----------------------+------------------------+------------------------+-----------------+-----------+
|item|user|rating             |orig_prediction   |avg_rating        |item_bias             |avg_diffs_item_rating |stderr_diffs_item_rating|stddev_diffs_item_rating|count_item_rating|orig_rating|
+----+----+-------------------+------------------+------------------+----------------------+----------------------+------------------------+------------------------+-----------------+-----------+
|155 |3000|1.0765790671313908 |3.849672341016376 |3.8081435881372085|0.11527734473140061   |0.15244754486771733   |0.32244150160575175     |0.957187309111377       |609              |5          |
|115 |3000|-0.6409185064717741|3.5671699145522853|3.8081435881372085|-0.16722508166543446  |-0.2549520987755066   |0.5246044207984703      |1.0420334869674042      |235              |3          |
|192 |3000|1.6025848518593546 |3.3236665570569057|3.8081435881372085|-0.41072843999656317  |-0.5479665969867656   |0.33413356277775874     |0.9422641194199479      |565              |5          |
|460 |3000|-0.9656275217198275|3.8918789298133114|3.8081435881372085|0.15748393358261895   |0.232597152603532     |0.47695798112324456     |0.8372193572997266      |270              |3          |
|1300|3000|-2.706435786179496 |3.6326871941913663|3.8081435881372085|-0.10170780195771253  |-0.17533108813720852  |0.7238705857600447      |1.1896607982784966      |128              |1          |
|1399|3000|0.43977139683058575|3.4864800113879246|3.8081435881372085|-0.2479149849677943   |-0.4246849415206675   |0.7130265101798374      |1.2230228603086486      |133              |4          |
|474 |3000|0.10332353465444351|3.8229278745707402|3.8081435881372085|0.08853287720834795   |0.13070533272610083   |0.4763479607525547      |0.942307697404239       |278              |4          |
|920 |3000|0.2593659279456562 |3.6668854801646726|3.8081435881372085|-0.06750951608286476  |-0.1020211391576168   |0.5112112347597139      |1.0017215024902655      |245              |4          |
|2740|3000|-1.7672063627465209|3.6934577708616523|3.8081435881372085|-0.04093722539068763  |-0.09846616878236984  |1.4052965935685728      |0.8243602928775247      |31               |2          |
|1002|3000|-2.2115282810175447|3.137779689127302 |3.8081435881372085|-0.596615307119664    |-1.0039477839414048   |0.6827388971769064      |1.164369751898941       |143              |1          |
|358 |3000|0.08982671270251798|3.8364246954375174|3.8081435881372085|0.10202969916027348   |0.14923346104311921   |0.4626472710528691      |1.0797872317817503      |305              |4          |
|819 |3000|-1.8868808914597501|2.813132299544523 |3.8081435881372085|-0.9212626966774585   |-1.3846141763725022   |0.5029526120683326      |1.0315091450188536      |255              |1          |
|1159|3000|1.0342586301460155 |3.891992777999069 |3.8081435881372085|0.157597781716776     |0.2608219291041707    |0.6549847736619928      |0.8870610997197517      |145              |5          |
|445 |3000|0.13739566396375197|3.7888557441245867|3.8081435881372085|0.05446074789903948   |0.07782132414349308   |0.42894336096448665     |0.9325533824301733      |342              |4          |
|1779|3000|-0.6825301438112514|3.608781551857008 |3.8081435881372085|-0.12561344432595709  |-0.22651093507598408  |0.8032379916930377      |0.9516504316593488      |98               |3          |
|930 |3000|-2.6781710059522865|3.6044224141105414|3.8081435881372085|-0.12997258218492214  |-0.20599305050280003  |0.5848961914884296      |0.9769205333942352      |186              |1          |
|84  |3000|0.05979648008423219|3.8664549280181117|3.8081435881372085|0.13205993177855926   |0.17428452687876556   |0.3197381259518545      |0.9998453554546729      |626              |4          |
|361 |3000|0.22856984261749846|3.6976815655068678|3.8081435881372085|-0.036713430754707    |-0.051251357560767555 |0.39598388129926176     |0.9097718339584423      |399              |4          |
|261 |3000|-0.8063390604249037|3.732590468515996 |3.8081435881372085|-0.0018045277123049101|-0.0025880325816531303|0.4341883275083955      |0.8153898951511179      |324              |3          |
|480 |3000|-0.2307082898300415|4.156959697988615 |3.8081435881372085|0.42256470169283294   |0.6106875806939599    |0.44519307516100953     |0.813106760102887       |308              |4          |
|712 |3000|-0.5662911563619386|3.492542564480853 |3.8081435881372085|-0.24185243177527     |-0.38860335825215125  |0.6067787923391347      |1.0039619872199375      |174              |3          |
|1408|3000|1.007397960468999  |3.9188534476927734|3.8081435881372085|0.18445845139379238   |0.3077984408482986    |0.6686600072945045      |0.8549765333196228      |138              |5          |
|83  |3000|0.21118831770771077|3.715063090416448 |3.8081435881372085|-0.019331905844919322 |-0.02511159926377352  |0.2989717343555744      |1.0166805113734032      |719              |4          |
|546 |3000|1.493442572435921  |3.432808835668477 |3.8081435881372085|-0.30158616057312965  |-0.43707440574727185  |0.4492521968404069      |1.0113142844583416      |318              |5          |
|1606|3000|-0.6849941172464733|3.6112455254114626|3.8081435881372085|-0.12314947089073523  |-0.24564358813720855  |0.9946783884695423      |0.957427107756338       |64               |3          |
|1568|3000|1.5039421266077504 |3.422309281434936 |3.8081435881372085|-0.312085714744959    |-0.5581435881372087   |0.7884304271771353      |1.0404442665730156      |104              |5          |
|1277|3000|0.06967389088745982|3.8565775172204755|3.8081435881372085|0.12218252097533164   |0.2308174508238304    |0.8891200556455364      |0.8019968244056837      |77               |4          |
|1138|3000|-0.7046545843288887|3.6309059924756384|3.8081435881372085|-0.1034890038083198   |-0.17238199873323506  |0.665703527811684       |1.1803019032355482      |151              |3          |
+----+----+-------------------+------------------+------------------+----------------------+----------------------+------------------------+------------------------+-----------------+-----------+


In [20]:
# Bring the joined rows to the driver once, then split them into a matrix of
# item factor vectors (one row per item) and a vector of the matching ratings.
collected_rows = filtered_item_factors_df2.collect()
filtered_item_factors = np.array([r['features'] for r in collected_rows])
item_ratings = np.array([r['rating'] for r in collected_rows])

# Report shape and contents of each array, factors first.
for arr in (filtered_item_factors, item_ratings):
    print(arr.shape)
    print(arr)


(37, 76)
[[ 0.17072918  0.19680859  0.21965314 ...,  0.25315386  0.18137982
   0.19664668]
 [ 0.17004225  0.19632128  0.21882892 ...,  0.25262594  0.18086621
   0.19610749]
 [ 0.1702452   0.19617768  0.21901591 ...,  0.25231725  0.18081307
   0.19602866]
 ..., 
 [ 0.17055109  0.19631332  0.21936715 ...,  0.25242195  0.18098387
   0.19620164]
 [ 0.17201108  0.19791038  0.22122876 ...,  0.25444826  0.18247357
   0.1978122 ]
 [ 0.17415997  0.20037992  0.22399253 ...,  0.25762111  0.18475056
   0.20028046]]
(37,)
[ 5.36298233  5.45245724  6.74634579  3.97958128  4.20604889  2.37075211
  4.6622507   5.53515798  5.40685973  5.11824118  5.32137605  3.44466695
  3.09691901  5.01685809  3.50894662  6.00216361  5.11175627  4.49832273
  2.36172007  5.05172291  5.22307164  6.07871711  3.7280378   4.17301462
  4.58152686  3.87842346  6.2052313   4.47357016  5.90032127  5.23544264
  6.58630592  4.53662521  5.24434425  6.64344013  3.45982981  5.01527288
  4.40924559]

In [21]:
# Estimate factors for the "new" user as the rating-weighted average of the
# factor vectors of the items they rated.
rating_total = item_ratings.sum()
if np.isclose(rating_total, 0.0):
    # The weights are normalised by their sum; a (near-)zero sum -- possible
    # when the ratings are centred residuals or when no rows were collected --
    # would silently produce inf/nan factors.
    raise ValueError(
        'Sum of item ratings is ~0 ({}); cannot normalise the weighted '
        'average of item factors'.format(rating_total)
    )
new_user_factors = np.dot(item_ratings, filtered_item_factors) / rating_total
print(rating_total, item_ratings.mean())

# Mean element-wise ratio between the estimated and the learned user factors.
print((new_user_factors / user_factors).mean())
# 35 * 3.5 ~ 122.6 # user 3000, sum(ratings) = 120, avg(ratings) = 3.4285714
# 40 * 3.5 ~ 142.4 # user 3001, sum(ratings) = 143, avg(ratings) = 3.575
# 33 * 4.2 ~ 138.3 # user 3002, sum(ratings) = 144, avg(ratings) = 4.3636
print(user_factors.shape)
print(user_factors)
print(new_user_factors.shape)
print(new_user_factors)
print(new_user_factors / user_factors)


177.627550163 4.80074459901
1.01670848543
(76,)
[ 0.16789205  0.19315493  0.21592818  0.22271433  0.22237928  0.3477492
  0.28929842  0.23317149  0.22982211  0.23681533  0.28274027  0.17150716
  0.27634874  0.19176431  0.18317878  0.30227867  0.24703541  0.19815147
  0.26319936  0.29353014  0.2363289   0.2156568   0.19762208  0.22113441
  0.24244733  0.24345616  0.23763084  0.27925789  0.16919948  0.22582635
  0.23148493  0.28613418  0.18429343  0.18181464  0.22680299  0.18386897
  0.18166584  0.23261586  0.23669659  0.26836577  0.21740308  0.22983143
  0.30741614  0.27158031  0.15876412  0.22398892  0.17326188  0.25610662
  0.21324217  0.20605992  0.25590393  0.28742862  0.26515552  0.20188807
  0.216563    0.2077571   0.22658667  0.30082989  0.19148758  0.2963239
  0.18150458  0.18148848  0.28001389  0.2405971   0.21947584  0.25990134
  0.24066807  0.24068025  0.26468825  0.15339871  0.24988359  0.27242273
  0.21852316  0.24832866  0.17809258  0.19306172]
(76,)
[ 0.17069177  0.19638767  0.21953114  0.22642189  0.22609441  0.35357813
  0.29413988  0.23706955  0.23367617  0.24075716  0.28747752  0.17436845
  0.28097637  0.19496627  0.18622261  0.30734642  0.25115221  0.20145735
  0.26759772  0.29843365  0.24025811  0.21925805  0.20091691  0.2248223
  0.24647953  0.24752074  0.24159321  0.28394191  0.17201172  0.2295928
  0.23534731  0.2909274   0.18738655  0.18485776  0.23060096  0.18694789
  0.1847035   0.23650115  0.24065936  0.27286768  0.22101421  0.2336694
  0.31256794  0.27610539  0.16141218  0.2277332   0.17616472  0.26038719
  0.2168148   0.20951361  0.26019475  0.29222944  0.26959793  0.20525436
  0.220181    0.2112294   0.23036556  0.30584479  0.19467918  0.30129273
  0.18452735  0.18450998  0.28468982  0.24461777  0.22313849  0.26426713
  0.24469372  0.24469128  0.26912891  0.15595233  0.25406796  0.27697799
  0.22217379  0.25248888  0.18107086  0.19629089]
[ 1.01667571  1.01673653  1.01668591  1.01664714  1.01670625  1.01676188
  1.0167352   1.01671757  1.01676973  1.01664514  1.0167548   1.01668315
  1.01674563  1.01669739  1.0166167   1.01676516  1.01666482  1.01668363
  1.01671115  1.01670532  1.01662604  1.01669897  1.01667238  1.01667716
  1.01663123  1.01669534  1.01667447  1.01677307  1.01662084  1.01667852
  1.01668521  1.01675162  1.01678363  1.01673749  1.01674566  1.01674513
  1.01672117  1.01670261  1.01674201  1.01677526  1.01661028  1.01669907
  1.0167584   1.01666203  1.01667917  1.01671636  1.01675403  1.01671404
  1.01675384  1.01676063  1.0167673   1.01670265  1.01675395  1.01667403
  1.01670643  1.01671325  1.01667744  1.01667021  1.01666737  1.01676824
  1.016654    1.0166484   1.01669891  1.01671124  1.01668819  1.01679787
  1.01672697  1.01666538  1.01677694  1.01664699  1.01674525  1.01672129
  1.01670591  1.0167529   1.01672319  1.0167261 ]

In [22]:
# make predictions for "new user"
# Collect every item's factor vector once and keep the ids in the same order
# so that each score can be matched back to its item.
item_factor_rows = item_factors_df.collect()
item_factors = np.array([r['features'] for r in item_factor_rows])
item_ids = np.array([r['id'] for r in item_factor_rows])
print(item_factors.shape)
print(item_ids.shape)

# One score per item: dot product of the new user's factors with each item's
# factor vector.
new_predictions = np.dot(new_user_factors, item_factors.T)
print(new_predictions.shape)
print(new_predictions)


(5055, 76)
(5055,)
(5055,)
[ 4.31537276  4.33377624  4.31210298 ...,  4.56491196  4.18962354
  4.29849381]

In [23]:
# Pair each item id with its score (plain Python values, not numpy scalars)
# and load the pairs into a two-column Spark DataFrame.
item_prediction_pairs = list(zip(item_ids.tolist(), new_predictions.tolist()))
new_prediction_df = spark.createDataFrame(item_prediction_pairs, ['item', 'prediction'])

In [24]:
# Row count of new_prediction_df, which was built with one row per
# (item id, prediction) pair.
new_prediction_df.count()


Out[24]:
5055

In [26]:
# Turn the raw factor scores back into rating-scale predictions by adding the
# average rating and each item's bias.
#
# `avg_rating` is taken from rating_stats_df, the frame that is cross-joined
# before `avg_rating` is used when the residual ratings are computed. This
# replaces a reference to `avg_rating_df`, a name that is never defined in
# this notebook and therefore fails on a fresh kernel.
# Assumes rating_stats_df carries the `avg_rating` column -- TODO confirm.
new_predicted_rating_df = (
    new_prediction_df
    .crossJoin(rating_stats_df)
    .join(item_bias_df, on='item')
    .withColumn(
        'prediction',
        F.col('prediction')
        + F.col('avg_rating')
        + F.col('item_bias')
        # NOTE(review): this constant offset has no visible counterpart where
        # the residual ratings are computed (no +5.0 is added there);
        # presumably it undoes a shift applied in an earlier revision --
        # confirm before relying on the absolute prediction values.
        - 5.0
    )
)

In [29]:
# Expose the DataFrames to Spark SQL under their variable names.
# createOrReplaceTempView supersedes registerTempTable (deprecated since
# Spark 2.0) and is safe to re-run because it replaces an existing view.
new_predicted_rating_df.createOrReplaceTempView('new_predicted_rating_df')
new_user_df.createOrReplaceTempView('new_user_df')
new_user_validate_df.createOrReplaceTempView('new_user_validate_df')


def compare_predictions(user_table):
    """Compare original and recomputed predictions for one user's items.

    Args:
        user_table: Name of a registered temp view with the columns item,
            user, rating and orig_prediction.

    Returns:
        A DataFrame joining `user_table` to the `new_predicted_rating_df`
        view on item, with `diff` (new prediction minus original prediction)
        and three per-user ranks: `actual_row_num` (by rating),
        `orig_row_num` (by original prediction) and `new_row_num` (by new
        prediction), all descending. Rows are ordered by `new_row_num`.
    """
    return spark.sql(
        '''
        select
            n.item, n.user, n.rating, n.orig_prediction, p.prediction, p.prediction - n.orig_prediction as diff,
            row_number() over (
                partition by n.user
                order by n.rating desc
            ) as actual_row_num,
            row_number() over (
                partition by n.user
                order by n.orig_prediction desc
            ) as orig_row_num,
            row_number() over (
                partition by n.user
                order by p.prediction desc
            ) as new_row_num
        from {user_table} n
        join new_predicted_rating_df p on n.item = p.item
        order by new_row_num
        '''.format(user_table=user_table)
    )


# Same comparison on the user's training rows and on the held-out rows.
compare_df = compare_predictions('new_user_df')
compare_validate_df = compare_predictions('new_user_validate_df')

# show() prints the table and returns None, so it is not wrapped in print().
compare_df.show(100)
compare_validate_df.show(100)


+----+----+------+------------------+------------------+-------------------+--------------+------------+-----------+
|item|user|rating|   orig_prediction|        prediction|               diff|actual_row_num|orig_row_num|new_row_num|
+----+----+------+------------------+------------------+-------------------+--------------+------------+-----------+
| 480|3000|     4|3.5927897257681742|  3.81553432228821|0.22274459652003564|             8|           1|          1|
|1173|3000|     3| 3.235255175185257| 3.456985877071812|0.22173070188655508|            22|           2|          2|
|1408|3000|     5| 3.224184429480438| 3.446095890419633|0.22191146093919478|             1|           3|          3|
| 358|3000|     4| 3.159180354462533| 3.381953848089708| 0.2227734936271748|             9|           4|          4|
| 460|3000|     3|3.1488988573980254| 3.370878155473447|0.22197929807542138|            23|           5|          5|
|1277|3000|     4| 3.130123051173655|3.3523843730220673| 0.2222613218484124|            10|           6|          6|
|1159|3000|     5|  3.09157607868851| 3.312974221623433|0.22139814293492321|             2|           7|          7|
|1755|3000|     5|3.0541775835641776|3.2762301366690796|0.22205255310490202|             3|           8|          8|
|  84|3000|     4|3.0505960277997666|3.2721373891663497| 0.2215413613665831|            11|           9|          9|
| 474|3000|     4| 3.003364389636584| 3.225228054418393|0.22186366478180908|            12|          10|         10|
| 445|3000|     4|2.9684564989757565| 3.189629642864084|0.22117314388832732|            13|          11|         11|
| 361|3000|     4| 2.904718511917526|3.1266863877281583|0.22196787581063226|            14|          12|         12|
|1207|3000|     3|2.9001820215821423|3.1217894486706577|0.22160742708851533|            24|          13|         13|
| 261|3000|     3|2.9001609896618383|3.1212148278781253|  0.221053838216287|            25|          14|         14|
|  90|3000|     5|  2.89288813118647|3.1143593242786434|0.22147119309217356|             4|          15|         15|
| 566|3000|     4|2.8563396113909443| 3.077853577847968| 0.2215139664570236|            15|          16|         16|
|  83|3000|     4| 2.849663430565349| 3.070918809872005|0.22125537930665562|            16|          17|         17|
|1300|3000|     1|2.7951153966894093|  3.01771887070006|0.22260347401065061|            34|          18|         18|
|1138|3000|     3|2.7897754487794515| 3.012932514366815| 0.2231570655873636|            26|          19|         19|
| 920|3000|     4| 2.787465541215588|3.0091162815094155|0.22165074029382748|            17|          20|         20|
|1019|3000|     4| 2.755550505240503| 2.977363367099808| 0.2218128618593047|            18|          21|         21|
| 930|3000|     1| 2.755358412798887| 2.977146767825861|0.22178835502697414|            35|          22|         22|
|  16|3000|     4|2.7055371679702684|2.9272462355945335|  0.221709067624265|            19|          23|         23|
|1269|3000|     4| 2.653100378460609| 2.874696514191873|0.22159613573126435|            20|          24|         24|
| 275|3000|     2|  2.65024631542013| 2.871917519684116| 0.2216712042639859|            31|          25|         25|
| 303|3000|     3|  2.64703310402618|   2.8688798318065|0.22184672778031977|            27|          26|         26|
|1399|3000|     4|2.6082919758004133|2.8305196977138767| 0.2222277219134634|            21|          27|         27|
|2740|3000|     2|2.6075902892939773|2.8282948160778814|0.22070452678390406|            32|          28|         28|
|1606|3000|     3|2.5423902201777757|2.7635417383201384|0.22115151814236267|            28|          29|         29|
|1779|3000|     3|2.5403346386890897|  2.76081219422941|0.22047755554032022|            29|          30|         30|
| 546|3000|     5|   2.5372346553823| 2.759131376742266|0.22189672135996563|             5|          31|         31|
| 297|3000|     3| 2.465421668746856|2.6873863176395067|0.22196464889265055|            30|          32|         32|
|1568|3000|     5|2.4464648262268573| 2.667798774902712|0.22133394867585476|             6|          33|         33|
| 192|3000|     5|  2.35906925777471|2.5806630281368568| 0.2215937703621469|             7|          34|         34|
| 952|3000|     2|2.3319565972099188|2.5527900296896755|0.22083343247975673|            33|          35|         35|
|1002|3000|     1|1.9824755064100739| 2.203633744277419|0.22115823786734534|            36|          36|         36|
| 819|3000|     1|1.6003150723374606|1.8219719213977097|0.22165684906024907|            37|          37|         37|
+----+----+------+------------------+------------------+-------------------+--------------+------------+-----------+

None
+----+----+------+------------------+------------------+-------------------+--------------+------------+-----------+
|item|user|rating|   orig_prediction|        prediction|               diff|actual_row_num|orig_row_num|new_row_num|
+----+----+------+------------------+------------------+-------------------+--------------+------------+-----------+
|  11|3000|     4|3.3061328933823084| 3.527706094319786|0.22157320093747757|             2|           1|          1|
| 108|3000|     4|3.2931053822728344| 3.514990858097507|0.22188547582467244|             3|           2|          2|
| 565|3000|     1|  3.25730409713311|3.4783955973966822|0.22109150026357227|            10|           3|          3|
|  81|3000|     3| 3.185044459935323| 3.406360301304616|0.22131584136929305|             4|           4|          4|
| 155|3000|     5|3.0966694269736745| 3.318378914754307|0.22170948778063249|             1|           5|          5|
|2294|3000|     2|2.9668669832650725|3.1887791320502394|0.22191214878516696|             9|           6|          6|
| 115|3000|     3|2.7437323403100766|2.9648518333746283| 0.2211194930645517|             5|           7|          7|
| 712|3000|     3|2.5115074421390213| 2.732983440206212|0.22147599806719054|             6|           8|          8|
| 740|3000|     3| 2.494558598644832|2.7153208112398293| 0.2207622125949973|             7|           9|          9|
| 123|3000|     3|2.2141604310865315| 2.436003266129572| 0.2218428350430406|             8|          10|         10|
+----+----+------+------------------+------------------+-------------------+--------------+------------+-----------+

None

In [33]:
# Number of ratings each item received across the full reviews dataset.
item_counts_df = reviews_df.groupBy('item').count()

# discount_factor = 1 - 1/sqrt(num_ratings): it grows towards 1 as an item
# collects more ratings and is smaller for rarely rated items.
discount_factor_df = (
    item_counts_df
    .withColumnRenamed('count', 'num_ratings')
    .withColumn(
        'discount_factor',
        1 - (1 / F.sqrt(F.col('num_ratings')))
    )
)

discount_factor_df.show(20)


+----+-----------+------------------+
|item|num_ratings|   discount_factor|
+----+-----------+------------------+
| 496|        273|0.9394772467331197|
| 148|        705|0.9623378211422645|
|1645|        148|0.9178005063473214|
|1959|         83|0.8902357400103097|
| 463|        361|0.9473684210526316|
| 833|        212|0.9313197180256555|
| 471|        381| 0.948768448042144|
|1342|        202|0.9296402455269708|
|1238|        154| 0.919417703597462|
|1829|        101|0.9004962809790011|
|1088|        161|0.9211889593760899|
|2366|         36|0.8333333333333334|
|2659|         40| 0.841886116991581|
|1591|         80|0.8881966011250105|
|1580|         77| 0.886039423540362|
|2866|         47|0.8541350085021054|
|2122|         66|0.8769085090206673|
|2142|         81|0.8888888888888888|
|3794|         27|0.8075499102701247|
|3997|         18|0.7642977396044841|
+----+-----------+------------------+
only showing top 20 rows


In [ ]: