Demonstrates matrix factorization with MXNet on the MovieLens 100k dataset. We perform collaborative filtering, where recommendations are based on users' previous ratings.
We learn embeddings for users and movies from the users' partial ratings of movies, and use them to estimate future movie ratings.
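Concretely, each user $u$ is mapped to an embedding $\mathbf{p}_u \in \mathbb{R}^k$ and each movie $i$ to an embedding $\mathbf{q}_i \in \mathbb{R}^k$, and a rating is predicted as their inner product:

$$\hat{r}_{u,i} = \mathbf{p}_u \cdot \mathbf{q}_i = \sum_{f=1}^{k} p_{u,f}\, q_{i,f}$$

(In the models below a ReLU is applied to the embeddings before the product.) The embeddings are trained so that $\hat{r}_{u,i}$ matches the observed ratings, presumably via a squared-error loss inside the train helper imported from matrix_fact.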
For more deep-learning-based architectures for recommendation, refer to this survey: Deep Learning based Recommender System: A Survey and New Perspectives.
In [1]:
import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import gluon, nd, autograd
import numpy as np
from matrix_fact import train
from movielens_data import get_dataset, max_id
In [2]:
ctx = [mx.gpu(0)] if len(mx.test_utils.list_gpus()) > 0 else [mx.cpu()]
batch_size = 128
In [3]:
train_dataset, test_dataset = get_dataset()
max_user, max_item = max_id('./ml-100k/u.data')
(max_user, max_item)
Out[3]:
In [4]:
train_data = gluon.data.DataLoader(train_dataset, shuffle=True, last_batch='rollover', batch_size=batch_size, num_workers=0)
test_data = gluon.data.DataLoader(test_dataset, shuffle=True, batch_size=batch_size, num_workers=0)
In [5]:
for user, item, score in test_data:
    print(user[0], item[0], score[0])
    break
In [6]:
class LinearMatrixFactorization(gluon.HybridBlock):
    def __init__(self, k, max_user=max_user, max_item=max_item):
        super(LinearMatrixFactorization, self).__init__(prefix='linearMF_')
        with self.name_scope():
            # user feature lookup
            self.user_embedding = gluon.nn.Embedding(input_dim=max_user, output_dim=k, prefix='emb_user_')
            # item feature lookup
            self.item_embedding = gluon.nn.Embedding(input_dim=max_item, output_dim=k, prefix='emb_item_')

    def hybrid_forward(self, F, user, item):
        user_embeddings = self.user_embedding(user).relu()
        item_embeddings = self.item_embedding(item).relu()
        # predict with the inner product: elementwise product followed by a sum
        pred = (user_embeddings * item_embeddings).sum(axis=1)
        return pred.flatten()

net1 = LinearMatrixFactorization(64)
net1.initialize(mx.init.Xavier(), ctx=ctx)
mx.viz.plot_network(net1(mx.sym.var('user'), mx.sym.var('item')), node_attrs={"fixedsize":"false"})
Out[6]:
In [7]:
net1.summary(user.as_in_context(ctx[0]), item.as_in_context(ctx[0]))
In [8]:
losses_1 = train(net1, train_data, test_data, epochs=15, learning_rate=1, ctx=ctx)
In [9]:
losses_1
Out[9]:
The optimizer used for training and its hyper-parameters greatly influence how fast the model converges. We can try the Adam optimizer, which often converges much faster than the momentum-free SGD we used before. You should see this model start over-fitting quickly.
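For reference, here is a minimal, hypothetical sketch of the kind of loop the train helper from matrix_fact.py presumably implements (the actual helper may differ in its loss and bookkeeping); switching optimizers is just a matter of changing the gluon.Trainer arguments:

# Hypothetical sketch of the train() internals, assuming an L2 loss;
# see matrix_fact.py for the real implementation.
loss_fn = gluon.loss.L2Loss()
trainer = gluon.Trainer(net1.collect_params(), 'adam', {'learning_rate': 0.01})
# ('sgd' with {'learning_rate': 1} would reproduce the previous run)
for epoch in range(15):
    for user, item, score in train_data:
        user = user.as_in_context(ctx[0])
        item = item.as_in_context(ctx[0])
        score = score.as_in_context(ctx[0])
        with autograd.record():
            loss = loss_fn(net1(user, item), score)
        loss.backward()
        # Adam's per-parameter adaptive step sizes explain the faster convergence
        trainer.step(batch_size)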
In [10]:
net1 = LinearMatrixFactorization(64)
net1.initialize(mx.init.Xavier(), ctx=ctx)
In [11]:
losses_1_adam = train(net1, train_data, test_data, epochs=15, optimizer='adam', learning_rate=0.01, ctx=ctx)
In [12]:
ratings = nd.dot(net1.user_embedding.weight.data(ctx=ctx[0]), net1.item_embedding.weight.data(ctx=ctx[0]).T).asnumpy()
ratings.shape
Out[12]:
In [13]:
# Helper function to plot the recommendation matrix
# and print the top 5 movies in several categories
def evaluate_embeddings(ratings):
    plt.figure(figsize=(15, 15))
    plt.xlabel('items')
    plt.ylabel('users')
    plt.title('Users estimated ratings of items sorted by mean ratings across users')
    im = plt.imshow(ratings[:, ratings.mean(axis=0).argsort()[::-1]])
    cb = plt.colorbar(im, fraction=0.026, pad=0.04, label="score")

    top_5_movies = ratings.mean(axis=0).argsort()[::-1][:5]        # highest mean predicted rating
    worst_5_movies = ratings.mean(axis=0).argsort()[:5]            # lowest mean predicted rating
    top_5_controversial = ratings.std(axis=0).argsort()[::-1][:5]  # highest variance across users

    with open('ml-100k/u.item', 'rb') as f:
        movies = f.readlines()

    print("Top 5 movies:")
    for movie in top_5_movies:
        print("{}, average rating {:.2f}".format(str(movies[int(movie)-1]).split("|")[1], ratings.mean(axis=0)[movie]))
    print("\nWorst 5 movies:")
    for movie in worst_5_movies:
        print("{}, average rating {:.2f}".format(str(movies[int(movie)-1]).split("|")[1], ratings.mean(axis=0)[movie]))
    print("\n5 most controversial movies:")
    for movie in top_5_controversial:
        print("{}, average rating {:.2f}".format(str(movies[int(movie)-1]).split("|")[1], ratings.mean(axis=0)[movie]))
In [14]:
evaluate_embeddings(ratings)
We can observe that some movies tend to be consistently recommended or not recommended, while others show more variance in their predicted scores.
In [15]:
class MLPMatrixFactorization(gluon.HybridBlock):
    def __init__(self, k, hidden, max_user=max_user, max_item=max_item):
        super(MLPMatrixFactorization, self).__init__(prefix='MLP_MF_')
        with self.name_scope():
            # user feature lookup and transformation
            self.user_embedding = gluon.nn.Embedding(input_dim=max_user, output_dim=k, prefix='emb_user_')
            self.user_mlp = gluon.nn.Dense(hidden, prefix='dense_user_')
            # item feature lookup and transformation
            self.item_embedding = gluon.nn.Embedding(input_dim=max_item, output_dim=k, prefix='emb_item_')
            self.item_mlp = gluon.nn.Dense(hidden, prefix='dense_item_')

    def hybrid_forward(self, F, user, item):
        user_embeddings = self.user_embedding(user)
        user_embeddings_relu = user_embeddings.relu()
        user_transformed = self.user_mlp(user_embeddings_relu)
        item_embeddings = self.item_embedding(item)
        item_embeddings_relu = item_embeddings.relu()
        item_transformed = self.item_mlp(item_embeddings_relu)
        # predict with the inner product: elementwise product followed by a sum
        pred = (user_transformed * item_transformed).sum(axis=1)
        return pred.flatten()

net2 = MLPMatrixFactorization(64, 64)
net2.initialize(mx.init.Xavier(), ctx=ctx)
mx.viz.plot_network(net2(mx.sym.var('user'), mx.sym.var('item')), node_attrs={"fixedsize":"false"})
Out[15]:
In [16]:
net2.summary(user.as_in_context(ctx[0]), item.as_in_context(ctx[0]))
In [17]:
losses_2 = train(net2, train_data, test_data, epochs=15, ctx=ctx)
We can try training with the Adam optimizer instead.
In [18]:
net2 = MLPMatrixFactorization(64, 64)
net2.initialize(mx.init.Xavier(), ctx=ctx)
In [19]:
losses_2_adam = train(net2, train_data, test_data, epochs=15, optimizer='adam', learning_rate=0.01, ctx=ctx)
Borrowing ideas from Deep Residual Learning for Image Recognition (He et al.), we build a deeper network that is aggressively regularized with dropout layers to avoid over-fitting, yet still achieves good performance.
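Each residual stage below computes

$$y = \mathrm{ReLU}\big(x + F(x)\big), \qquad F(x) = W_2\,\mathrm{Dropout}_{0.5}\big(\mathrm{ReLU}(W_1 x + b_1)\big) + b_2$$

where $F$ is the Dense-Dropout-Dense block returned by get_residual_block, so each stage can fall back to (a ReLU of) the identity mapping when the residual branch does not help.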
In [20]:
def get_residual_block(prefix='res_block_', hidden=64):
    block = gluon.nn.HybridSequential(prefix=prefix)
    with block.name_scope():
        block.add(
            gluon.nn.Dense(hidden, activation='relu', prefix='d1_'),
            gluon.nn.Dropout(0.5, prefix='dropout_'),
            gluon.nn.Dense(hidden, prefix='d2_')
        )
    return block

class ResNetMatrixFactorization(gluon.HybridBlock):
    def __init__(self, k, hidden, max_user=max_user, max_item=max_item):
        super(ResNetMatrixFactorization, self).__init__(prefix='ResNet_MF_')
        with self.name_scope():
            # user feature lookup and residual stages
            self.user_embedding = gluon.nn.Embedding(input_dim=max_user, output_dim=k, prefix='emb_user_')
            self.user_block1 = get_residual_block('u_block1_', hidden)
            self.user_dropout = gluon.nn.Dropout(0.5)
            self.user_block2 = get_residual_block('u_block2_', hidden)
            # item feature lookup and residual stages
            self.item_embedding = gluon.nn.Embedding(input_dim=max_item, output_dim=k, prefix='emb_item_')
            self.item_block1 = get_residual_block('i_block1_', hidden)
            self.item_dropout = gluon.nn.Dropout(0.5)
            self.item_block2 = get_residual_block('i_block2_', hidden)

    def hybrid_forward(self, F, user, item):
        user_embeddings = self.user_embedding(user)
        user_block1 = self.user_block1(user_embeddings)
        user1 = (user_embeddings + user_block1).relu()
        user2 = self.user_dropout(user1)
        user_block2 = self.user_block2(user2)
        user_transformed = (user2 + user_block2).relu()
        item_embeddings = self.item_embedding(item)
        item_block1 = self.item_block1(item_embeddings)
        item1 = (item_embeddings + item_block1).relu()
        item2 = self.item_dropout(item1)
        item_block2 = self.item_block2(item2)
        item_transformed = (item2 + item_block2).relu()
        # predict with the inner product: elementwise product followed by a sum
        pred = (user_transformed * item_transformed).sum(axis=1)
        return pred.flatten()

net3 = ResNetMatrixFactorization(128, 128)
net3.initialize(mx.init.Xavier(), ctx=ctx)
mx.viz.plot_network(net3(mx.sym.var('user'), mx.sym.var('item')), node_attrs={"fixedsize":"false"})
Out[20]:
In [21]:
net3.summary(user.as_in_context(ctx[0]), item.as_in_context(ctx[0]))
In [22]:
losses_3 = train(net3, train_data, test_data, epochs=15, optimizer='adam', learning_rate=0.001, ctx=ctx, num_epoch_lr=10)
Unlike the linear model, where we can directly take the inner product of the embedding weight matrices, here we have to run every user/item combination through the network and store the predicted ratings.
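As a quick sanity check before scoring the full grid, a single (user, item) pair can be predicted directly; the ids 1 and 50 below are arbitrary valid MovieLens 100k ids chosen for illustration:

# score one hypothetical (user, item) pair; dropout is inactive outside autograd.record()
score = net3(nd.array([1], ctx=ctx[0]), nd.array([50], ctx=ctx[0]))
print(score.asscalar())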
In [23]:
%%time
users = []
items = []
for i in range(max_user):
    for j in range(max_item):
        users.append(i+1)
        items.append(j+1)

dataset = gluon.data.ArrayDataset(np.array(users).astype('float32'), np.array(items).astype('float32'))
dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

ratings = np.zeros((max_user+1, max_item+1))
for users, items in dataloader:
    users = users.as_in_context(ctx[0])
    items = items.as_in_context(ctx[0])
    scores = net3(users, items).asnumpy()
    ratings[users.asnumpy().astype('int32'), items.asnumpy().astype('int32')] = scores.reshape(-1)
In [24]:
evaluate_embeddings(ratings)
In [25]:
train_1, test_1 = list(zip(*losses_1))
train_1a, test_1a = list(zip(*losses_1_adam))
train_2, test_2 = list(zip(*losses_2))
train_2a, test_2a = list(zip(*losses_2_adam))
train_3a, test_3a = list(zip(*losses_3))
In [26]:
losses_1_adam
Out[26]:
In [27]:
plt.figure(figsize=(20,20))
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('Evolution of training and testing losses')
x = range(15)
h1, = plt.plot(x, test_1, 'c', label='test loss Linear')
h2, = plt.plot(x, train_1, 'c--', label='train loss Linear')
h3, = plt.plot(x, test_1a, 'b', label='test loss Linear Adam')
h4, = plt.plot(x, train_1a, 'b--', label='train loss Linear Adam')
h5, = plt.plot(x, test_2, 'r', label='test loss MLP')
h6, = plt.plot(x, train_2, 'r--', label='train loss MLP')
h7, = plt.plot(x, test_2a, 'm', label='test loss MLP Adam')
h8, = plt.plot(x, train_2a, 'm--', label='train loss MLP Adam')
h9, = plt.plot(x, test_3a, 'g', label='test loss ResNet Adam')
h10, = plt.plot(x, train_3a, 'g--', label='train loss ResNet Adam')
l = plt.legend(handles=[h1, h2, h3, h4, h5, h6, h7, h8, h9, h10])
This tutorial is inspired by examples from xlvector on GitHub.