Factorization Machine example


In [2]:
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from reco.datasets import loadMovieLens100k
from reco.recommender import FM

First we'll test with only the bare minimum: the userId, itemId and rating columns.


In [3]:
train, test, _, _ = loadMovieLens100k(train_test_split=True)
print(train.head())


   userId  itemId  rating
0       1       1       5
1       1       2       3
2       1       3       4
3       1       4       3
4       1       5       3

So we have the user IDs, item IDs and the corresponding ratings in three columns. Next we separate out the rating column, since that is what we are going to predict. We also need to explicitly cast userId and itemId to strings so that the model treats them as categorical variables rather than as numeric values. We'll do this for both the train and test sets.


In [4]:
y_train = train['rating']
train.drop(['rating'], axis=1, inplace=True)

train['userId'] = train['userId'].astype('str')
train['itemId'] = train['itemId'].astype('str')

y_test = test['rating']
test.drop(['rating'], axis=1, inplace=True)

test['userId'] = test['userId'].astype('str')
test['itemId'] = test['itemId'].astype('str')
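
To illustrate why the string cast matters, here is a minimal sketch using pandas' own `pd.get_dummies` on a made-up toy frame. It is only an illustration of one-hot encoding driven by dtype; how `reco` encodes columns internally is not shown in this notebook and may differ.

```python
import pandas as pd

# Toy frame in the same shape as the notebook's train set (values are made up).
toy = pd.DataFrame({"userId": [1, 1, 2], "itemId": [10, 20, 10]})

# Left as integers, the IDs look like ordinary numeric features,
# so get_dummies passes them through unchanged.
as_numbers = pd.get_dummies(toy)

# Cast to str, each distinct ID becomes its own indicator column.
as_categories = pd.get_dummies(toy.astype(str))

print(list(as_numbers.columns))
print(list(as_categories.columns))
```

With integer IDs the frame keeps its two numeric columns, which would imply that user 2 is "twice" user 1. After the cast there are four indicator columns, one per distinct user or item.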

Next we'll train the model, using 60 iterations here. The hyperparameter values below are only a starting point; tweak them to improve performance.


In [5]:
f = FM(k=10, iterations=60, learning_rate=0.003, regularizer=0.005)
f.fit(X=train, y=y_train)


epoch 0 time 1.3084863953578691 mse 1.103937278612045
epoch 1 time 1.3081417801617612 mse 0.9918084223513931
epoch 2 time 1.350830668491486 mse 0.9441643862460812
epoch 3 time 1.2941405570513194 mse 0.918135365419286
epoch 4 time 1.4166798860326555 mse 0.9012910466433169
epoch 5 time 1.263930385542353 mse 0.8891798658571386
epoch 6 time 1.3548125238632345 mse 0.8798711358070531
epoch 7 time 1.2486451517382644 mse 0.8723796737482616
epoch 8 time 1.334478039259876 mse 0.8661391582826142
epoch 9 time 1.3298616543206911 mse 0.8607936833117162
epoch 10 time 1.3456986371107167 mse 0.8560972256015189
epoch 11 time 1.3040315601878643 mse 0.8518732029213533
epoch 12 time 1.2945989499629675 mse 0.8479907528352713
epoch 13 time 1.2999260809521154 mse 0.8443463899115101
epoch 14 time 1.5280591527425713 mse 0.8408569059317575
epoch 15 time 1.2989221384813376 mse 0.8374490178145726
epoch 16 time 1.3151708361458674 mse 0.8340607647692536
epoch 17 time 1.3207630837991111 mse 0.8306367098314906
epoch 18 time 1.3031858854367826 mse 0.8271262597987651
epoch 19 time 1.2858701570168947 mse 0.8234816346995734
epoch 20 time 1.3624046336704012 mse 0.8196576523896757
epoch 21 time 1.3828550840222924 mse 0.8156125948517245
epoch 22 time 1.378272248922304 mse 0.8113088650218219
epoch 23 time 1.3169453309016994 mse 0.8067143435877449
epoch 24 time 1.2876844010387316 mse 0.8018056674977648
epoch 25 time 1.4080057939113644 mse 0.7965697647492628
epoch 26 time 1.3344948141794717 mse 0.7910051668891229
epoch 27 time 1.3392107546197565 mse 0.7851247160243937
epoch 28 time 1.3637294876465518 mse 0.7789566240764688
epoch 29 time 1.3135400222231226 mse 0.7725430262729343
epoch 30 time 1.3222819433671447 mse 0.7659354928117184
epoch 31 time 1.326075627901652 mse 0.7591904014333649
epoch 32 time 1.298048019301369 mse 0.7523653422892986
epoch 33 time 1.3614546626800674 mse 0.7455137151528289
epoch 34 time 1.3428180916778558 mse 0.7386838396123836
epoch 35 time 1.288762736631206 mse 0.7319171458386186
epoch 36 time 1.3293376204193095 mse 0.725242658316449
epoch 37 time 1.3000971121975908 mse 0.718682402264571
epoch 38 time 1.3512872380423104 mse 0.7122551547086085
epoch 39 time 1.502409571331519 mse 0.7059741139342881
epoch 40 time 1.5003907462250012 mse 0.6998471114763135
epoch 41 time 1.1779399596453786 mse 0.6938797681625268
epoch 42 time 1.3463758333214813 mse 0.6880713515266206
epoch 43 time 1.2903691175188854 mse 0.6824222927777186
epoch 44 time 1.4245148675018413 mse 0.6769357576798637
epoch 45 time 1.393100548503405 mse 0.6716091010529329
epoch 46 time 1.396695122035716 mse 0.666441059633255
epoch 47 time 1.3886603002200033 mse 0.6614292919955723
epoch 48 time 1.2709485013615023 mse 0.6565718336179993
epoch 49 time 1.2618918681389175 mse 0.6518625029640728
epoch 50 time 1.1742144688057579 mse 0.6472990698513169
epoch 51 time 1.299181420390795 mse 0.6428769946181062
epoch 52 time 1.3833145709504322 mse 0.6385938200562767
epoch 53 time 1.376311042017889 mse 0.6344446901176573
epoch 54 time 1.4503672066367415 mse 0.6304259856180925
epoch 55 time 1.3107200123696714 mse 0.6265355494433584
epoch 56 time 1.3655674353591252 mse 0.6227682319377223
epoch 57 time 1.3806196436496663 mse 0.6191188047229711
epoch 58 time 1.417063156478278 mse 0.6155828409804598
epoch 59 time 1.4089167449800186 mse 0.6121578331236209
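
For context, the sketch below scores a single row with the standard second-order factorization machine formula (global bias, linear weights, and pairwise interactions through latent factors). It assumes that `k` in the `FM(...)` call is the latent dimension and that `reco` follows this standard form; neither is confirmed by the notebook. The function `fm_predict` and all parameter values are hypothetical.

```python
import numpy as np

def fm_predict(x, w0, w, V):
    """Score one feature vector x with a second-order factorization machine.

    x: (n,) features, w0: global bias, w: (n,) linear weights,
    V: (n, k) latent factors, one k-dimensional row per feature.
    """
    linear = w0 + x @ w
    # Pairwise term sum_{i<j} <v_i, v_j> x_i x_j, computed in O(n*k) as
    # 0.5 * sum_f [ (sum_i V[i,f] x_i)^2 - sum_i V[i,f]^2 x_i^2 ].
    xv = x @ V
    x2v2 = (x ** 2) @ (V ** 2)
    return linear + 0.5 * np.sum(xv ** 2 - x2v2)

rng = np.random.default_rng(0)
n, k = 6, 10  # 6 one-hot columns; k=10 mirrors the FM(...) call (assumed meaning)
w0, w, V = 0.1, rng.normal(size=n), rng.normal(scale=0.1, size=(n, k))

# One rating event: the user in column 1 and the item in column 4 are active.
x = np.zeros(n)
x[[1, 4]] = 1.0

print(fm_predict(x, w0, w, V))
```

With only one user column and one item column active, the score reduces to `w0 + w[1] + w[4] + V[1] @ V[4]`, which is the familiar matrix-factorization prediction with biases.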

In [7]:
y_pred = f.predict(test)
print("MSE: {}".format(mean_squared_error(y_test, y_pred)))


MSE: 0.9257072902506759
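
The hyperparameters above were fixed by hand. One way to tune them is a small grid search scored on a validation split carved out of the training data, so the test set stays untouched. The sketch below is a generic helper, not part of `reco`: `grid_search` and the stand-in `ShrunkMean` model are hypothetical, and the helper relies only on the `fit(X=..., y=...)` and `predict(...)` calls used in this notebook.

```python
from itertools import product

def grid_search(make_model, grid, X_fit, y_fit, X_val, y_val, score):
    """Try every combination in `grid`; return (best_score, best_params)."""
    best_score, best_params = float("inf"), None
    for values in product(*grid.values()):
        params = dict(zip(grid, values))
        model = make_model(**params)
        model.fit(X=X_fit, y=y_fit)
        s = score(y_val, model.predict(X_val))
        if s < best_score:  # lower is better, as with MSE
            best_score, best_params = s, params
    return best_score, best_params

class ShrunkMean:
    """Stand-in model so the sketch runs without reco: predicts shrink * mean(y)."""
    def __init__(self, shrink):
        self.shrink = shrink
    def fit(self, X, y):
        self.value = self.shrink * sum(y) / len(y)
    def predict(self, X):
        return [self.value] * len(X)

def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

best = grid_search(ShrunkMean, {"shrink": [0.5, 1.0, 1.5]},
                   X_fit=[0, 1, 2], y_fit=[3, 4, 5],
                   X_val=[3, 4], y_val=[4, 4], score=mse)
print(best)
```

In this notebook one would presumably pass `FM` as `make_model`, `mean_squared_error` as `score`, and a grid over `k`, `learning_rate` and `regularizer`. At roughly 1.3 seconds per epoch, a large grid gets expensive quickly.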

Now we'll try again with all the columns, training the model on the full feature set.


In [8]:
train, test, _, _ = loadMovieLens100k(train_test_split=True, all_columns=True)
print(train.head())


  userId itemId  rating  age gender  occupation  5  6  7  8 ...  14  15  16  \
0      1      1     5.0   24      M  technician  0  0  0  1 ...   0   0   0   
1      2      1     4.0   53      F       other  0  0  0  1 ...   0   0   0   
2      6      1     4.0   42      M   executive  0  0  0  1 ...   0   0   0   
3     10      1     4.0   53      M      lawyer  0  0  0  1 ...   0   0   0   
4     13      1     3.0   47      M    educator  0  0  0  1 ...   0   0   0   

   17  18  19  20  21  22  23  
0   0   0   0   0   0   0   0  
1   0   0   0   0   0   0   0  
2   0   0   0   0   0   0   0  
3   0   0   0   0   0   0   0  
4   0   0   0   0   0   0   0  

[5 rows x 25 columns]

This time, we also need to change the data type of the columns gender and occupation to string so that they are treated as categorical variables and hence one-hot encoded.


In [9]:
y_train = train['rating']
train.drop(['rating'], axis=1, inplace=True)
train['userId'] = train['userId'].astype('str')
train['itemId'] = train['itemId'].astype('str')
train['gender'] = train['gender'].astype('str')
train['occupation'] = train['occupation'].astype('str')


y_test = test['rating']
test.drop(['rating'], axis=1, inplace=True)
test['userId'] = test['userId'].astype('str')
test['itemId'] = test['itemId'].astype('str')
test['gender'] = test['gender'].astype('str')
test['occupation'] = test['occupation'].astype('str')

Before training, we also need to normalize the age column: its values are on a much larger scale than those of the other columns, which would hamper the performance of the model. We choose min-max normalization here.


In [11]:
# Scale both sets with the training set's min and max so they share one scale.
age_min, age_max = train['age'].min(), train['age'].max()
train['age'] = (train['age'] - age_min) / (age_max - age_min)
test['age'] = (test['age'] - age_min) / (age_max - age_min)

In [17]:
f = FM(k=10, iterations=60, learning_rate=0.003, regularizer=0.005)
f.fit(X=train, y=y_train)


epoch 0 time 7.168705366406357 mse 0.9966038516961236
epoch 1 time 7.158128050254618 mse 0.9222644952229134
epoch 2 time 7.203841164850473 mse 0.8932051216741685
epoch 3 time 7.34335999451514 mse 0.8760120494763252
epoch 4 time 7.146919851257735 mse 0.8635113788610401
epoch 5 time 7.183824309702686 mse 0.8532038650239079
epoch 6 time 7.159351890041307 mse 0.8440652823004936
epoch 7 time 7.132119266761038 mse 0.8356230268739092
epoch 8 time 7.165921459096808 mse 0.8276217140381995
epoch 9 time 7.163955876126693 mse 0.8199056253557706
epoch 10 time 7.1612620428422815 mse 0.8123744204657366
epoch 11 time 7.17143894895662 mse 0.8049685822245547
epoch 12 time 7.152979973298443 mse 0.797652948928141
epoch 13 time 7.142680537337128 mse 0.7904183100351516
epoch 14 time 7.199183571956382 mse 0.7832648678385027
epoch 15 time 7.281502478494531 mse 0.7762022575076472
epoch 16 time 7.312662256321346 mse 0.7692416144569367
epoch 17 time 7.2228478781003105 mse 0.7623965802119932
epoch 18 time 7.257174833682257 mse 0.7556843192832916
epoch 19 time 7.31355351509319 mse 0.7491185749053207
epoch 20 time 7.302148393126572 mse 0.7427076533727938
epoch 21 time 7.311096354044366 mse 0.736460116691938
epoch 22 time 7.273140181074268 mse 0.7303784034808909
epoch 23 time 7.195639687854964 mse 0.7244711058271042
epoch 24 time 7.209344067823167 mse 0.7187406231693347
epoch 25 time 7.146166803236611 mse 0.7131862807652857
epoch 26 time 7.173868394921101 mse 0.7078013352816999
epoch 27 time 7.156963287359304 mse 0.7025815270893169
epoch 28 time 7.131278332748025 mse 0.6975221286584601
epoch 29 time 7.165291305595474 mse 0.6926182447293624
epoch 30 time 7.152377899554267 mse 0.6878619028923977
epoch 31 time 7.172890708845898 mse 0.6832444490279198
epoch 32 time 7.155224895147512 mse 0.6787638895336717
epoch 33 time 7.175669510744683 mse 0.6744148571740275
epoch 34 time 7.132099209792159 mse 0.6701925872971451
epoch 35 time 7.124478655555777 mse 0.6660875943410898
epoch 36 time 7.142547431997173 mse 0.6620930160500093
epoch 37 time 7.194092748530693 mse 0.6582047665804932
epoch 38 time 7.149415302883881 mse 0.6544209314345084
epoch 39 time 7.1314435292388225 mse 0.6507388755535733
epoch 40 time 7.140981529719738 mse 0.6471559918564882
epoch 41 time 7.181439718414822 mse 0.6436703190300617
epoch 42 time 7.11765235729581 mse 0.6402755400997237
epoch 43 time 7.156116153919356 mse 0.6369706816313878
epoch 44 time 7.119326931877822 mse 0.6337548100162012
epoch 45 time 7.14502720272003 mse 0.6306248569176972
epoch 46 time 7.124262769634242 mse 0.6275790932310583
epoch 47 time 7.132586411804823 mse 0.624612621706506
epoch 48 time 7.481870870315561 mse 0.6217249439365939
epoch 49 time 7.162597472311063 mse 0.6189151638696572
epoch 50 time 7.1752844169382115 mse 0.6161795463602048
epoch 51 time 7.130334561184554 mse 0.6135146150166595
epoch 52 time 7.20361251540271 mse 0.6109206924603029
epoch 53 time 7.2039892217494526 mse 0.608396365959085
epoch 54 time 7.1821099858543676 mse 0.605938737549157
epoch 55 time 7.118031251674893 mse 0.6035460913600308
epoch 56 time 6.2072293339369935 mse 0.601216417630386
epoch 57 time 7.172039928684626 mse 0.598948075474274
epoch 58 time 7.12781795857245 mse 0.5967390789444804
epoch 59 time 7.033004289638939 mse 0.59458741861308

In [18]:
y_pred = f.predict(test)
print("MSE: {}".format(mean_squared_error(y_test, y_pred)))


MSE: 1.0619317734747407
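
The two runs can be put side by side with a few lines of arithmetic. The numbers below are copied from the outputs above; treating the per-epoch `mse` in the training log as training error is an assumption.

```python
import math

# (final-epoch MSE from the training log, test MSE), copied from the outputs above.
runs = {
    "ids only":    (0.6121578331236209, 0.9257072902506759),
    "all columns": (0.59458741861308, 1.0619317734747407),
}

summary = {}
for name, (train_mse, test_mse) in runs.items():
    summary[name] = {
        "test_rmse": math.sqrt(test_mse),  # back in rating units (stars)
        "gap": test_mse - train_mse,       # larger gap hints at overfitting
    }
    print(f"{name:12s} train {train_mse:.3f}  test {test_mse:.3f}  "
          f"test RMSE {summary[name]['test_rmse']:.3f}  "
          f"gap {summary[name]['gap']:.3f}")
```

With these hyperparameters the extra columns lower the logged training MSE slightly but raise the test MSE, so the gap between the two widens. One plausible reading is that the larger model overfits and would need stronger regularization or fewer iterations, though the notebook does not test this.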
