This notebook expands on the simple model to predict ratings for a (user, anime) pair by adding an interaction term between each user and each item.

rating(user, item) = offset + user_bias + item_bias + dot_product(user_latent_features, item_latent_features)

Formulated this way, the model looks like a natural extension of the linear model from the previous notebook. However, the "latent_features" terms are a bit deceptive. While a proper explanation is outside the scope of this notebook, the key point is that, as the dot_product indicates, the user and item latent feature terms are vectors, not scalars.

What's actually going on in this formulation is a dimensionality reduction problem. We essentially embed each of the m users and n items (i.e. anime) in a k-dimensional space. There is no proper label for what these dimensions represent, but similar shows will group near each other, and likewise for users. Since users and items live in the same space, we can take their dot product as a measure of how compatible a user and anime pair are. k itself is a hyper-parameter; small values keep the embeddings easy to visualize, and the code below uses k = 3.
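
To make the formulation concrete, here is a minimal NumPy sketch of a single prediction under this model. All of the names and numbers here (offset, user_bias, item_vecs, etc.) are illustrative placeholders, not the variables used later in this notebook.

import numpy as np

k = 2                                    # number of latent dimensions
offset = 7.0                             # global offset (e.g. the overall mean rating)
user_bias = np.array([0.5, -0.3])        # one bias per user (2 users in this toy example)
item_bias = np.array([0.2, -0.1, 0.4])   # one bias per anime (3 anime in this toy example)
user_vecs = np.random.randn(2, k) * 0.1  # one k-dimensional latent vector per user
item_vecs = np.random.randn(3, k) * 0.1  # one k-dimensional latent vector per anime

def predict(user, item):
    # rating(user, item) = offset + user_bias + item_bias + dot(user_latent, item_latent)
    return offset + user_bias[user] + item_bias[item] + user_vecs[user] @ item_vecs[item]

print(predict(0, 2))  # predicted rating for user 0 on anime 2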

Therefore, though the formulation above looks similar to that of the linear model, the two are conceptually quite different. The objective function will be mean absolute error (MAE) with L2 regularization on all the parameters. This model is implemented in TensorFlow, just for the sake of practice.
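
Concretely, the objective minimized below (cell In [11]) is the sum of absolute errors plus an L2 penalty; dividing by the number of ratings would turn the first term into the MAE proper, but that only rescales the objective:

objective = sum(|prediction - rating|) + lam * (sum(user_bias^2) + sum(item_bias^2) + sum(user_latent_features^2) + sum(item_latent_features^2))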

Much of this code builds off of the linear model.


In [1]:
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

%matplotlib inline
matplotlib.style.use('seaborn')

In [2]:
import pandas

In [3]:
from animerec.data import get_data
users, anime = get_data()

In [4]:
from sklearn.model_selection import train_test_split
train, test = train_test_split(users, test_size=0.1) # split the full dataset into train and test sets
train, valid = train_test_split(train, test_size=0.2) # hold out part of the training data as a validation set

In [5]:
from animerec.data import remove_users
train = remove_users(train, 10)

In [6]:
#define validation set
valid_users = valid['user_id']
valid_anime = valid['anime_id']
valid_ratings = valid['rating']

In [7]:
#initialize some local variables
nUsers = len(train.user_id.unique())
nAnime = len(train.anime_id.unique())

# we'll need some data structures in order to vectorize computations
from collections import defaultdict
user_ids = train.user_id
item_ids = train.anime_id

user_index = defaultdict(lambda: -1) # maps a user_id to the index in the bias term.
item_index = defaultdict(lambda: -1) # maps an anime_id to the index in the bias term.

counter = 0
for user in user_ids:
    if user_index[user] == -1:
        user_index[user] = counter
        counter += 1 

counter = 0
for item in item_ids:
    if item_index[item] == -1:
        item_index[item] = counter
        counter += 1
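
As an aside, the two loops above can be written more compactly; the following is just an equivalent sketch (the names user_index_alt / item_index_alt are mine), and the defaultdict version has the added benefit that ids unseen during training map to -1, which the validation and test code below relies on.

# equivalent first-appearance indexing using pandas unique() (sketch)
user_index_alt = {uid: idx for idx, uid in enumerate(train.user_id.unique())}
item_index_alt = {aid: idx for idx, aid in enumerate(train.anime_id.unique())}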

In [8]:
#Terms needed for the latent factors.
k = 3 # hyper-parameter: number of latent dimensions

In [9]:
import tensorflow as tf

In [10]:
y = tf.cast(tf.constant(train['rating'].as_matrix(), shape=[len(train),1]), tf.float32)

In [11]:
def objective(alpha, Bi, Bu, Gi, Gu, y, lam): #Gi, Gu = gamma_i, gamma_u = latent factors for items, users 
    
    '''
    Like in the linear model, we need to construct the "full" matrix entries for each (user, item) pair. However, with
    the addition of the latent factor terms, it would waste memory to hold each "full" variable in its own tensor.
    Instead, we create one intermediate tensor representing our prediction ("pred") and incrementally add each
    additional term to it.
    '''
    pred = tf.gather(Bi, train.anime_id.map(lambda _id: item_index[_id]).as_matrix()) #Bi_full
    pred += tf.gather(Bu, train.user_id.map(lambda _id: user_index[_id]).as_matrix()) #Bu_full
    
    Gi_full = tf.gather(Gi, train.anime_id.map(lambda _id: item_index[_id]).as_matrix()) #latent factors of items
    Gu_full = tf.gather(Gu, train.user_id.map(lambda _id: user_index[_id]).as_matrix()) #latent factors of users
    pred += tf.expand_dims(tf.einsum('ij,ji->i', Gi_full, tf.transpose(Gu_full)), 1) # row-wise dot product of the latent factors

    pred += tf.tile(alpha, (len(train), 1)) #alpha_full
    obj = tf.reduce_sum(tf.abs(pred - y)) # sum of absolute errors


    # regularization
    obj += lam * tf.reduce_sum(Bi**2)
    obj += lam * tf.reduce_sum(Bu**2) 
    obj += lam * tf.reduce_sum(Gi**2) 
    obj += lam * tf.reduce_sum(Gu**2)
    
    return obj

In [12]:
#initialize alpha, Bi, Bu, Gi, Gu 
alpha = tf.Variable(tf.constant([6.9], shape=[1, 1]))
Bi = tf.Variable(tf.constant([0.0]*nAnime, shape=[nAnime, 1]))
Bu = tf.Variable(tf.constant([0.0]*nUsers, shape=[nUsers, 1]))
Gi = tf.Variable(tf.random_normal([nAnime, k], stddev=0.35)) # latent factors need random initialization to break symmetry
Gu = tf.Variable(tf.random_normal([nUsers, k], stddev=0.35))

In [13]:
optimizer = tf.train.AdamOptimizer(0.01)

In [14]:
obj = objective(alpha, Bi, Bu, Gi, Gu, y, 0.1)

In [15]:
trainer = optimizer.minimize(obj)

In [16]:
sess = tf.Session()

sess.run(tf.global_variables_initializer())
tLoss = []
vLoss = []
prev = 10e10
for iteration in range(500):
    cvalues = sess.run([trainer, obj])
    print("objective = " + str(cvalues[1]))
    tLoss.append(cvalues[1])
    
    if not iteration % 5: # evaluate on the validation set every 5 iterations
        cAlpha, cBi, cBu, cGi, cGu, cLoss = sess.run([alpha, Bi, Bu, Gi, Gu, obj])
        indices = valid_users.map(lambda x: user_index[x])
        bu = indices.map(lambda x: 0.0 if x == -1 else float(cBu[x]))
        gu = indices.map(lambda x: np.zeros(k) if x == -1 else cGu[x])
        gu = np.vstack(gu.as_matrix()).astype(np.float)


        indices = valid_anime.map(lambda x: item_index[x])
        bi = indices.map(lambda x: 0.0 if x == -1 else float(cBi[x]))
        gi = indices.map(lambda x: np.zeros(k) if x == -1 else cGi[x])
        gi = np.vstack(gi.as_matrix()).astype(np.float)

        g = np.einsum('ij,ji->i', gi, np.transpose(gu)) 

        preds = bu + bi + g + float(cAlpha)
        MAE = 1.0/len(valid) * sum(abs(valid_ratings-preds))
        vLoss.append(MAE)
        if MAE > prev: break
        else: prev = MAE
    
cAlpha, cBi, cBu, cGi, cGu, cLoss = sess.run([alpha, Bi, Bu, Gi, Gu, obj])
print("\nFinal train loss is ", cLoss)


objective = 4.5784e+06
objective = 4.51974e+06
objective = 4.46316e+06
objective = 4.40923e+06
objective = 4.35878e+06
objective = 4.31225e+06
objective = 4.26918e+06
objective = 4.22897e+06
objective = 4.19105e+06
objective = 4.15489e+06
objective = 4.12004e+06
objective = 4.08617e+06
objective = 4.05301e+06
objective = 4.02032e+06
objective = 3.98796e+06
objective = 3.95576e+06
objective = 3.92362e+06
objective = 3.89145e+06
objective = 3.85917e+06
objective = 3.82674e+06
objective = 3.79408e+06
objective = 3.76119e+06
objective = 3.72802e+06
objective = 3.69457e+06
objective = 3.66086e+06
objective = 3.62688e+06
objective = 3.59272e+06
objective = 3.55844e+06
objective = 3.52418e+06
objective = 3.49009e+06
objective = 3.45637e+06
objective = 3.42332e+06
objective = 3.39119e+06
objective = 3.36028e+06
objective = 3.33084e+06
objective = 3.30305e+06
objective = 3.27699e+06
objective = 3.25263e+06
objective = 3.22986e+06
objective = 3.20846e+06
objective = 3.18822e+06
objective = 3.16893e+06
objective = 3.15039e+06
objective = 3.13238e+06
objective = 3.11476e+06
objective = 3.09738e+06
objective = 3.08015e+06
objective = 3.06303e+06
objective = 3.04597e+06
objective = 3.02898e+06
objective = 3.01208e+06
objective = 2.99533e+06
objective = 2.97873e+06
objective = 2.96238e+06
objective = 2.94629e+06
objective = 2.9305e+06
objective = 2.91508e+06
objective = 2.90004e+06
objective = 2.8854e+06
objective = 2.87116e+06
objective = 2.85735e+06
objective = 2.84395e+06
objective = 2.83096e+06
objective = 2.81836e+06
objective = 2.80615e+06
objective = 2.79434e+06
objective = 2.7829e+06
objective = 2.77181e+06
objective = 2.76105e+06
objective = 2.75062e+06
objective = 2.7405e+06
objective = 2.73067e+06
objective = 2.72115e+06
objective = 2.71191e+06
objective = 2.70295e+06
objective = 2.69426e+06
objective = 2.68584e+06
objective = 2.67767e+06
objective = 2.66977e+06
objective = 2.66212e+06
objective = 2.65472e+06
objective = 2.64757e+06
objective = 2.64067e+06
objective = 2.63401e+06
objective = 2.62757e+06
objective = 2.62136e+06
objective = 2.61539e+06
objective = 2.60963e+06
objective = 2.60409e+06
objective = 2.59875e+06
objective = 2.59363e+06
objective = 2.58871e+06
objective = 2.58398e+06
objective = 2.57943e+06
objective = 2.57506e+06
objective = 2.57086e+06
objective = 2.56682e+06
objective = 2.56295e+06
objective = 2.55922e+06
objective = 2.55565e+06
objective = 2.55222e+06
objective = 2.54891e+06
objective = 2.54574e+06
objective = 2.54269e+06
objective = 2.53976e+06
objective = 2.53694e+06
objective = 2.53423e+06
objective = 2.53162e+06
objective = 2.5291e+06
objective = 2.52666e+06
objective = 2.52432e+06
objective = 2.52206e+06
objective = 2.51987e+06
objective = 2.51776e+06
objective = 2.51573e+06
objective = 2.51376e+06
objective = 2.51185e+06
objective = 2.51e+06
objective = 2.50821e+06
objective = 2.50648e+06
objective = 2.50478e+06
objective = 2.50315e+06
objective = 2.50156e+06
objective = 2.5e+06
objective = 2.4985e+06
objective = 2.49703e+06
objective = 2.49561e+06
objective = 2.49422e+06
objective = 2.49287e+06
objective = 2.49155e+06
objective = 2.49027e+06
objective = 2.48901e+06
objective = 2.4878e+06
objective = 2.4866e+06
objective = 2.48544e+06
objective = 2.48431e+06
objective = 2.4832e+06
objective = 2.48212e+06
objective = 2.48106e+06
objective = 2.48002e+06
objective = 2.47901e+06
objective = 2.47802e+06
objective = 2.47704e+06
objective = 2.4761e+06
objective = 2.47516e+06
objective = 2.47425e+06
objective = 2.47336e+06
objective = 2.47248e+06
objective = 2.47162e+06
objective = 2.47079e+06
objective = 2.46996e+06
objective = 2.46915e+06
objective = 2.46835e+06
objective = 2.46758e+06
objective = 2.46681e+06
objective = 2.46605e+06
objective = 2.46531e+06
objective = 2.46458e+06
objective = 2.46387e+06
objective = 2.46316e+06
objective = 2.46247e+06
objective = 2.46178e+06
objective = 2.4611e+06
objective = 2.46044e+06
objective = 2.45979e+06
objective = 2.45914e+06
objective = 2.45851e+06
objective = 2.45789e+06
objective = 2.45728e+06
objective = 2.45668e+06
objective = 2.45608e+06
objective = 2.45549e+06
objective = 2.45491e+06
objective = 2.45434e+06
objective = 2.45378e+06
objective = 2.45323e+06
objective = 2.45268e+06
objective = 2.45214e+06
objective = 2.4516e+06
objective = 2.45107e+06
objective = 2.45054e+06
objective = 2.45004e+06
objective = 2.44953e+06
objective = 2.44903e+06
objective = 2.44853e+06
objective = 2.44804e+06
objective = 2.44757e+06
objective = 2.44709e+06
objective = 2.44661e+06
objective = 2.44614e+06
objective = 2.44568e+06
objective = 2.44523e+06
objective = 2.44478e+06
objective = 2.44433e+06
objective = 2.44389e+06
objective = 2.44346e+06
objective = 2.44302e+06
objective = 2.4426e+06
objective = 2.44217e+06
objective = 2.44176e+06
objective = 2.44134e+06
objective = 2.44093e+06
objective = 2.44052e+06
objective = 2.44011e+06
objective = 2.43972e+06
objective = 2.43932e+06
objective = 2.43892e+06
objective = 2.43854e+06
objective = 2.43815e+06
objective = 2.43777e+06
objective = 2.43739e+06
objective = 2.43702e+06
objective = 2.43665e+06
objective = 2.43628e+06
objective = 2.43591e+06
objective = 2.43555e+06
objective = 2.4352e+06
objective = 2.43484e+06
objective = 2.43449e+06
objective = 2.43414e+06
objective = 2.43379e+06
objective = 2.43345e+06
objective = 2.43311e+06
objective = 2.43277e+06
objective = 2.43243e+06
objective = 2.4321e+06
objective = 2.43177e+06
objective = 2.43144e+06
objective = 2.43111e+06
objective = 2.43079e+06
objective = 2.43047e+06
objective = 2.43015e+06
objective = 2.42983e+06
objective = 2.42951e+06
objective = 2.4292e+06
objective = 2.42889e+06
objective = 2.42857e+06
objective = 2.42826e+06
objective = 2.42797e+06
objective = 2.42766e+06
objective = 2.42736e+06
objective = 2.42706e+06
objective = 2.42676e+06
objective = 2.42647e+06
objective = 2.42617e+06
objective = 2.42588e+06
objective = 2.42558e+06
objective = 2.4253e+06
objective = 2.42501e+06
objective = 2.42472e+06
objective = 2.42444e+06
objective = 2.42415e+06
objective = 2.42388e+06
objective = 2.4236e+06
objective = 2.42332e+06
objective = 2.42304e+06
objective = 2.42277e+06
objective = 2.42249e+06
objective = 2.42222e+06
objective = 2.42195e+06
objective = 2.42168e+06
objective = 2.42141e+06
objective = 2.42115e+06
objective = 2.42088e+06
objective = 2.42062e+06
objective = 2.42036e+06
objective = 2.4201e+06
objective = 2.41984e+06
objective = 2.41958e+06
objective = 2.41931e+06
objective = 2.41906e+06
objective = 2.41881e+06
objective = 2.41855e+06
objective = 2.41831e+06
objective = 2.41805e+06
objective = 2.4178e+06
objective = 2.41755e+06
objective = 2.41731e+06
objective = 2.41706e+06
objective = 2.41682e+06
objective = 2.41657e+06
objective = 2.41633e+06
objective = 2.41609e+06
objective = 2.41585e+06
objective = 2.41561e+06
objective = 2.41537e+06
objective = 2.41513e+06
objective = 2.41489e+06
objective = 2.41466e+06
objective = 2.41443e+06
objective = 2.41419e+06
objective = 2.41396e+06
objective = 2.41373e+06
objective = 2.4135e+06
objective = 2.41327e+06
objective = 2.41304e+06
objective = 2.41281e+06
objective = 2.41259e+06
objective = 2.41236e+06
objective = 2.41213e+06
objective = 2.41191e+06
objective = 2.41168e+06
objective = 2.41146e+06
objective = 2.41124e+06
objective = 2.41102e+06
objective = 2.4108e+06
objective = 2.41058e+06
objective = 2.41036e+06
objective = 2.41014e+06
objective = 2.40993e+06
objective = 2.40971e+06
objective = 2.4095e+06
objective = 2.40928e+06
objective = 2.40907e+06
objective = 2.40886e+06
objective = 2.40864e+06
objective = 2.40843e+06
objective = 2.40822e+06
objective = 2.40801e+06
objective = 2.40780e+06
objective = 2.4076e+06
objective = 2.40739e+06
objective = 2.40718e+06
objective = 2.40698e+06
objective = 2.40677e+06
objective = 2.40656e+06
objective = 2.40636e+06
objective = 2.40616e+06
objective = 2.40595e+06
objective = 2.40575e+06
objective = 2.40555e+06
objective = 2.40535e+06
objective = 2.40515e+06
objective = 2.40496e+06
objective = 2.40475e+06
objective = 2.40456e+06
objective = 2.40436e+06
objective = 2.40416e+06
objective = 2.40397e+06
objective = 2.40377e+06
objective = 2.40358e+06
objective = 2.40339e+06
objective = 2.40319e+06
objective = 2.403e+06
objective = 2.40281e+06
objective = 2.40261e+06
objective = 2.40242e+06
objective = 2.40223e+06
objective = 2.40204e+06
objective = 2.40185e+06
objective = 2.40166e+06
objective = 2.40148e+06
objective = 2.40129e+06
objective = 2.4011e+06
objective = 2.40092e+06
objective = 2.40073e+06
objective = 2.40054e+06
objective = 2.40036e+06
objective = 2.40017e+06
objective = 2.39999e+06
objective = 2.39981e+06
objective = 2.39963e+06
objective = 2.39944e+06
objective = 2.39926e+06
objective = 2.39908e+06
objective = 2.3989e+06
objective = 2.39872e+06
objective = 2.39854e+06
objective = 2.39836e+06
objective = 2.39818e+06
objective = 2.398e+06
objective = 2.39782e+06
objective = 2.39765e+06
objective = 2.39747e+06
objective = 2.3973e+06
objective = 2.39712e+06
objective = 2.39694e+06
objective = 2.39677e+06
objective = 2.39659e+06
objective = 2.39643e+06
objective = 2.39625e+06
objective = 2.39608e+06
objective = 2.39591e+06
objective = 2.39573e+06
objective = 2.39556e+06
objective = 2.39539e+06
objective = 2.39522e+06
objective = 2.39506e+06
objective = 2.39489e+06
objective = 2.39472e+06
objective = 2.39455e+06
objective = 2.39438e+06
objective = 2.39422e+06
objective = 2.39405e+06
objective = 2.39388e+06
objective = 2.39372e+06
objective = 2.39355e+06
objective = 2.39339e+06
objective = 2.39322e+06
objective = 2.39306e+06
objective = 2.3929e+06
objective = 2.39274e+06
objective = 2.39258e+06
objective = 2.39242e+06
objective = 2.39226e+06
objective = 2.3921e+06
objective = 2.39194e+06
objective = 2.39178e+06
objective = 2.39162e+06
objective = 2.39147e+06
objective = 2.39131e+06
objective = 2.39115e+06
objective = 2.39099e+06
objective = 2.39084e+06
objective = 2.39068e+06
objective = 2.39053e+06
objective = 2.39038e+06
objective = 2.39023e+06
objective = 2.39007e+06
objective = 2.38992e+06
objective = 2.38977e+06
objective = 2.38961e+06
objective = 2.38947e+06
objective = 2.38931e+06
objective = 2.38916e+06
objective = 2.38901e+06
objective = 2.38887e+06
objective = 2.38871e+06
objective = 2.38857e+06
objective = 2.38843e+06
objective = 2.38828e+06
objective = 2.38813e+06
objective = 2.38798e+06
objective = 2.38784e+06
objective = 2.3877e+06
objective = 2.38756e+06
objective = 2.38742e+06
objective = 2.38727e+06
objective = 2.38713e+06
objective = 2.38698e+06
objective = 2.38684e+06
objective = 2.3867e+06
objective = 2.38656e+06
objective = 2.38642e+06
objective = 2.38628e+06
objective = 2.38614e+06
objective = 2.386e+06
objective = 2.38587e+06
objective = 2.38573e+06
objective = 2.38559e+06
objective = 2.38545e+06
objective = 2.38531e+06
objective = 2.38518e+06
objective = 2.38504e+06
objective = 2.38491e+06
objective = 2.38477e+06
objective = 2.38463e+06
objective = 2.3845e+06
objective = 2.38437e+06
objective = 2.38424e+06
objective = 2.38411e+06
objective = 2.38398e+06
objective = 2.38385e+06
objective = 2.38372e+06
objective = 2.38359e+06
objective = 2.38346e+06
objective = 2.38333e+06
objective = 2.3832e+06
objective = 2.38307e+06
objective = 2.38294e+06
objective = 2.38281e+06
objective = 2.38269e+06
objective = 2.38256e+06
objective = 2.38244e+06
objective = 2.38231e+06
objective = 2.38219e+06
objective = 2.38206e+06
objective = 2.38194e+06
objective = 2.38182e+06
objective = 2.3817e+06
objective = 2.38157e+06
objective = 2.38145e+06
objective = 2.38133e+06
objective = 2.38121e+06
objective = 2.38109e+06
objective = 2.38097e+06
objective = 2.38085e+06
objective = 2.38073e+06
objective = 2.38061e+06
objective = 2.3805e+06
objective = 2.38038e+06
objective = 2.38027e+06
objective = 2.38015e+06
objective = 2.38003e+06
objective = 2.37992e+06
objective = 2.3798e+06
objective = 2.37969e+06
objective = 2.37958e+06
objective = 2.37946e+06
objective = 2.37935e+06

Final train loss is  2.37924e+06

Let's plot the objective and see how it decreases.


In [17]:
fig, ax1 = plt.subplots()
plt.title('Latent factor model performance vs. iterations')
ax1.plot(tLoss, 'b-')
ax1.set_xlabel('Training Iterations')
ax1.set_ylabel('Train Loss')

ax2 = ax1.twinx()
ax2.plot(range(0, len(vLoss)*5, 5), vLoss, 'r.')
ax2.set_ylabel('Validation MAE')

fig.tight_layout()


We can see quite clearly that this model does not overfit, just like the linear model; L2 regularization looks to be more than enough. We can also see that after about 150 iterations the validation MAE mostly flatlines, which suggests a natural place for early stopping.
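
The training loop above already stops as soon as the validation MAE ticks up once between checks; a slightly more forgiving alternative is patience-based early stopping. A minimal sketch of the idea (not wired into the loop above):

# sketch: stop only after `patience` consecutive validation checks without improvement
def make_early_stopper(patience=3):
    state = {'best': float('inf'), 'bad_checks': 0}
    def should_stop(mae):
        if mae < state['best']:
            state['best'] = mae
            state['bad_checks'] = 0
        else:
            state['bad_checks'] += 1
        return state['bad_checks'] >= patience
    return should_stop

should_stop = make_early_stopper(patience=3)
# inside the training loop, one would replace the break condition with: if should_stop(MAE): break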

Let's test the model on our test data now.


In [18]:
test_users = test['user_id']
test_anime = test['anime_id']
test_ratings = test['rating']

In [19]:
indices = test_users.map(lambda x: user_index[x])
bu = indices.map(lambda x: 0.0 if x == -1 else float(cBu[x]))
gu = indices.map(lambda x: np.zeros(k) if x == -1 else cGu[x])
gu = np.vstack(gu.as_matrix()).astype(np.float)

indices = test_anime.map(lambda x: item_index[x])
bi = indices.map(lambda x: 0.0 if x == -1 else float(cBi[x]))
gi = indices.map(lambda x: np.zeros(k) if x == -1 else cGi[x])
gi= np.vstack(gi.as_matrix()).astype(np.float)

g = np.einsum('ij,ji->i', gi, np.transpose(gu)) 

preds = bu + bi + g + float(cAlpha)
MAE = 1.0/len(test) * sum(abs(test_ratings-preds))
print ('MAE on test set is: ', float(MAE))


MAE on test set is:  0.8725550557027907

In [20]:
sess.close()

So, this is an improvement over the linear model, though not by much. The important thing, however, is that we are now actually running a recommender system: the linear model would recommend the same shows to everyone, whereas this model captures a genuine interaction between users and items.

We could likely improve on this by tuning our hyper-parameters (the number of latent dimensions k, the regularization strength, and the learning rate).
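
A rough sketch of what that tuning might look like, assuming the steps above were wrapped in a function train_and_validate(k, lam) that returns the validation MAE (such a helper is not defined in this notebook):

# hypothetical grid search over the latent dimension k and the regularization strength lam
best = None
for k_candidate in [2, 3, 5, 10]:
    for lam_candidate in [0.01, 0.1, 1.0]:
        mae = train_and_validate(k_candidate, lam_candidate)  # assumed helper
        if best is None or mae < best[0]:
            best = (mae, k_candidate, lam_candidate)
print('best validation MAE %.4f with k=%d, lam=%.2f' % best)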