This notebook uses a simple model to predict the rating for a (user, anime) pair: rating(user, item) = offset + user_bias + item_bias

Put simply, the predicted rating is the sum of three terms. The user_bias term captures whether a user is predisposed to rate higher or lower; for example, a user who gives all perfect scores will have a large positive user_bias. The item_bias term is analogous for an anime, so a highly rated anime will have a large positive item_bias. Finally, the offset accounts for the overall mean rating. It will be initialized to 6.9 (the mean rating found in the exploratory analysis), though since the model is trained only on users with 10 or more ratings, that number is just a rough guess. Note that while training is restricted to users with 10 or more ratings, the validation and test sets may include users with fewer.
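As a minimal sketch of the prediction rule (the ids and bias values below are made up for illustration; unseen users or items fall back to a bias of 0.0, just as the evaluation cells later in the notebook do):

```python
# Toy bias model: rating(user, item) = offset + user_bias + item_bias.
# All ids and values here are hypothetical.
offset = 6.9
user_bias = {"u1": 1.2, "u2": -0.8}   # u1 rates generously, u2 harshly
item_bias = {"a1": 0.9, "a2": -1.5}   # a1 is well liked, a2 is not

def predict(user, item):
    # Unseen users/items contribute a bias of 0.0.
    return offset + user_bias.get(user, 0.0) + item_bias.get(item, 0.0)

print(predict("u1", "a1"))        # ≈ 9.0
print(predict("u2", "a2"))        # ≈ 4.6
print(predict("new_user", "a1"))  # ≈ 7.8 (unknown user, known anime)
```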

The objective function will be absolute error with L2 regularization (the code sums absolute errors rather than averaging them, which has the same minimizer as MAE). The model is implemented in TensorFlow, mostly for the sake of practice.
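As a sketch, the regularized objective can be written in NumPy. Note the names and toy values here are illustrative, and the bias vectors are already expanded to one entry per rating, whereas the TensorFlow version later in the notebook gathers them from per-user and per-anime tables:

```python
import numpy as np

def objective(alpha, bi, bu, y, lam):
    # Sum of absolute errors plus an L2 penalty on the bias vectors.
    preds = alpha + bi + bu  # one (bi, bu) pair per rating
    return np.sum(np.abs(preds - y)) + lam * (np.sum(bi**2) + np.sum(bu**2))

y  = np.array([7.0, 8.0])     # observed ratings
bi = np.array([0.5, -0.5])    # item bias for each rating
bu = np.array([1.0, 0.0])     # user bias for each rating
print(objective(6.9, bi, bu, y, lam=1.0))  # ≈ 4.5 (3.0 error + 1.5 penalty)
```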

Broadly speaking, this code is organized into two parts.

In the first part, I do some basic data prep:

1) Import some packages for plotting.
2) Use my animerec package to get the data.
3) Split the data into train, validation, and test sets.
4) Define some data structures that'll be used to vectorize computations.

In the second part, I use tensorflow to implement this simple model, using a validation set for early stopping. I then see the performance on the test set.
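The early-stopping pattern used in the second part can be sketched in isolation (the validation-loss values below are fabricated for illustration):

```python
def steps_until_early_stop(val_losses):
    # Stop as soon as the validation loss rises above its previous value.
    prev = float('inf')
    for step, loss in enumerate(val_losses):
        if loss > prev:
            return step  # step at which training would be cut off
        prev = loss
    return len(val_losses)

print(steps_until_early_stop([0.95, 0.93, 0.92, 0.94, 0.91]))  # 3
```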


In [1]:
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
matplotlib.style.use('seaborn')

In [2]:
from animerec.data import get_data
users, anime = get_data()

In [3]:
from sklearn.model_selection import train_test_split
train, test = train_test_split(users, test_size=0.1)   # hold out 10% of the data as a test set
train, valid = train_test_split(train, test_size=0.2)  # hold out 20% of the remainder as a validation set

In [4]:
from animerec.data import remove_users
train = remove_users(train, 10)  # drop users with fewer than 10 ratings from the training set

In [5]:
#define validation set
valid_users = valid['user_id']
valid_anime = valid['anime_id']
valid_ratings = valid['rating']

In [6]:
#initialize some local variables
nUsers = len(train.user_id.unique())
nAnime = len(train.anime_id.unique())

# we'll need some data structures in order to vectorize computations
from collections import defaultdict
user_ids = train.user_id
item_ids = train.anime_id

user_index = defaultdict(lambda: -1) # maps a user_id to its row in the user bias vector; -1 means unseen
item_index = defaultdict(lambda: -1) # maps an anime_id to its row in the anime bias vector; -1 means unseen

counter = 0
for user in user_ids:
    if user_index[user] == -1:
        user_index[user] = counter
        counter += 1 

counter = 0
for item in item_ids:
    if item_index[item] == -1:
        item_index[item] = counter
        counter += 1
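The two counter loops above can be collapsed into a comprehension: `dict.fromkeys` preserves first-occurrence order, which matches the counter-based indexing. The ids below are made up; note the notebook relies on the defaultdict returning -1 for unseen ids, so the defaultdict wrapper is kept:

```python
from collections import defaultdict

user_ids = ['a', 'b', 'a', 'c']  # made-up ids

# First occurrence order defines the index, as in the counter loop.
user_index = defaultdict(lambda: -1)
user_index.update({u: i for i, u in enumerate(dict.fromkeys(user_ids))})

print(dict(user_index))       # {'a': 0, 'b': 1, 'c': 2}
print(user_index['never_seen'])  # -1 for unseen ids, as in the notebook
```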

In [7]:
import tensorflow as tf

In [8]:
y = tf.cast(tf.constant(train['rating'].as_matrix(), shape=[len(train),1]), tf.float32)

In [9]:
def objective(alpha, Bi, Bu, y, lam):
    # expand the per-anime and per-user biases to one row per training example
    Bi_full = tf.gather(Bi, train.anime_id.map(lambda _id: item_index[_id]).as_matrix())
    Bu_full = tf.gather(Bu, train.user_id.map(lambda _id: user_index[_id]).as_matrix())
    alpha_full = tf.tile(alpha, (len(train), 1))
    
    return tf.reduce_sum(abs(alpha_full+Bi_full+Bu_full-y)) + lam * (tf.reduce_sum(Bi**2) + tf.reduce_sum(Bu**2))
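The `tf.gather` calls above just select rows of a bias vector by index, so each rating row picks up its anime's (or user's) bias; the NumPy equivalent is fancy indexing (toy values):

```python
import numpy as np

Bi = np.array([[0.5], [-0.3], [0.1]])  # one bias per distinct anime
row_indices = np.array([0, 2, 2, 1])   # which anime each rating refers to

Bi_full = Bi[row_indices]              # shape (4, 1), like tf.gather(Bi, row_indices)
print(Bi_full.ravel())                 # [ 0.5  0.1  0.1 -0.3]
```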

In [10]:
#initialize alpha, Bi, Bu
alpha = tf.Variable(tf.constant([6.9], shape=[1, 1]))
Bi = tf.Variable(tf.constant([0.0]*nAnime, shape=[nAnime, 1]))
Bu = tf.Variable(tf.constant([0.0]*nUsers, shape=[nUsers, 1]))

In [11]:
optimizer = tf.train.AdamOptimizer(0.01)

In [12]:
obj = objective(alpha, Bi, Bu, y, 1)

In [13]:
trainer = optimizer.minimize(obj)

In [14]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
tLoss = []
vLoss = []
prev = 10e10
for iteration in range(500):
    cvalues = sess.run([trainer, obj])
    print("objective = " + str(cvalues[1]))
    tLoss.append(cvalues[1])
    
    if not iteration % 5:
        cAlpha, cBi, cBu, cLoss = sess.run([alpha, Bi, Bu, obj])
        indices = valid_users.map(lambda x: user_index[x])
        bu = indices.map(lambda x: 0.0 if x == -1 else float(cBu[x]))
        indices = valid_anime.map(lambda x: item_index[x])
        bi = indices.map(lambda x: 0.0 if x == -1 else float(cBi[x]))
        preds = bu + bi + float(cAlpha)
        MAE = 1.0/len(valid) * sum(abs(valid_ratings-preds))
        vLoss.append(MAE)
        if MAE > prev: break
        else: prev = MAE
    
cAlpha, cBi, cBu, cLoss = sess.run([alpha, Bi, Bu, obj])
print("\nFinal train loss is ", cLoss)


objective = 3.58663e+06
objective = 3.53622e+06
objective = 3.48537e+06
[... 495 similar lines omitted; the objective decreases steadily ...]
objective = 2.20978e+06
objective = 2.20976e+06

Final train loss is  2.20975e+06

Let's plot the objective and see how it decreases.


In [15]:
fig, ax1 = plt.subplots()
plt.title('Linear model performance vs. iterations')
ax1.plot(tLoss, 'b-')
ax1.set_xlabel('Training Iterations')
ax1.set_ylabel('Train Loss')

ax2 = ax1.twinx()
ax2.plot(range(0, len(vLoss)*5, 5), vLoss, 'r.')
ax2.set_ylabel('Validation MAE')

fig.tight_layout()


By around 50 iterations the training curve hits an elbow, and from there improvement is incremental. We can also see quite clearly that the model does not overfit: the validation MAE keeps falling along with the training loss. That's not surprising, since this is a very simple model (one that likely underfits the data) and it uses L2 regularization, but it's good to confirm anyway.

Let's test the model on our test data now.


In [16]:
test_users = test['user_id']
test_anime = test['anime_id']
test_ratings = test['rating']

In [17]:
indices = test_users.map(lambda x: user_index[x])
bu = indices.map(lambda x: 0.0 if x == -1 else float(cBu[x]))
indices = test_anime.map(lambda x: item_index[x])
bi = indices.map(lambda x: 0.0 if x == -1 else float(cBi[x]))
preds = bu + bi + float(cAlpha)
MAE = 1.0/len(test) * sum(abs(test_ratings-preds))
print ('MAE on test set is: ', float(MAE))


MAE on test set is:  0.909247177087559

In [18]:
sess.close()

So a simple linear model with no real user/item interaction predicts ratings within about 0.9 points on average. That's a solid baseline. What's also notable is that the test set includes users with very few ratings, possibly just one, whom the model never trained on. In the next notebook, we'll use a simple latent factor method and see how much that improves performance; after that, we'll consider adding features to the latent factor method.