In [ ]:
from fastai.collab import *   # Quick access to collab filtering functionality

Collaborative filtering example

`collab` models use data in a `DataFrame` of users, items, and ratings.


In [ ]:
# Fetch the sample dataset identified by URLs.ML_SAMPLE and keep its local location in `path`.
# NOTE(review): untar_data presumably downloads and extracts the archive on first use and
# reuses the local copy afterwards — confirm against the fastai docs.
path = untar_data(URLs.ML_SAMPLE)
# Bare expression so the notebook displays the resolved location.
path


Out[ ]:
PosixPath('/home/sgugger/.fastai/data/movie_lens_sample')

In [ ]:
# Load the ratings table from the CSV file inside the dataset folder.
ratings = pd.read_csv(path/'ratings.csv')
# The return value is discarded, so this call is relied on to update `ratings` in place.
# NOTE(review): presumably converts the 'userId' and 'movieId' columns to categorical dtype — confirm.
series2cat(ratings, 'userId', 'movieId')
# Preview the first rows (columns: userId, movieId, rating, timestamp).
ratings.head()


Out[ ]:
userId movieId rating timestamp
0 73 1097 4.0 1255504951
1 561 924 3.5 1172695223
2 157 260 3.5 1291598691
3 358 1210 5.0 957481884
4 130 316 2.0 1138999234

In [ ]:
# Wrap the ratings DataFrame in a CollabDataBunch for training.
# NOTE(review): no column names are passed, so from_df presumably takes the user, item and
# rating columns by position; `seed=42` presumably fixes the train/validation split — confirm.
data = CollabDataBunch.from_df(ratings, seed=42)

In [ ]:
# [low, high] range handed to collab_learner via its y_range argument.
# NOTE(review): the upper bound is 5.5 while the ratings previewed earlier top out at 5.0;
# presumably headroom so the model output can actually reach the top rating — confirm.
y_range = [0, 5.5]

That's all we need to create and train a model:


In [ ]:
# Build the collaborative-filtering learner on `data`, with 50 factors and
# predictions bounded by `y_range`.
learn = collab_learner(
    data,
    n_factors=50,
    y_range=y_range,
)
# Train with the one-cycle schedule: 4 epochs, maximum learning rate 5e-3.
learn.fit_one_cycle(cyc_len=4, max_lr=5e-3)


epoch train_loss valid_loss time
0 1.785755 1.323691 00:01
1 0.907917 0.682382 00:01
2 0.668087 0.658977 00:01
3 0.573576 0.650910 00:01

In [ ]:
# Predict the rating for the first row of `ratings`; the call returns a 3-tuple
# (see the cell output).
learn.predict(ratings.iloc[0])


Out[ ]:
(FloatItem 4.145714, tensor(4.1457), tensor(4.1457))

In [ ]: