This notebook expands on the latent factor model with features by accounting for anime whose genre is listed as "N/A," i.e. anime with unknown genres. I use the Expectation-Maximization (EM) algorithm to estimate the best genre representation for these shows. In our model, the "unobserved variables" are the missing genres. We want to find the MLE for our data, but that is complicated by the need to find a likely configuration of the unobserved variables (i.e. figuring out which genres fit each anime).

The expectation (E) step is an inference step. Each anime has a one-hot encoding over genres: the entry for a genre is 1 if the anime falls under it and 0 otherwise. The aim of the E step is to find, for each anime without genre information, the probability (0 to 1) that it falls under each genre. This probability is derived from our latent factor model.

The maximization (M) step simply looks to maximize the likelihood of the data, i.e. it minimizes our usual objective from the LFM with features. Note that for the anime with unknown genres, we use the inferred genre probabilities from the E step in place of the one-hot encodings.
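As a rough sketch of that alternation, the overall structure is the loop below; `infer_genres`, `fit_lfm`, `params`, `unknown_idx`, and `n_em_iters` are hypothetical placeholders standing in for the E and M steps implemented later, not functions defined in this notebook.

# Hedged sketch of the EM loop (infer_genres and fit_lfm are hypothetical
# placeholders for the E and M steps, not functions defined here).
for em_iter in range(n_em_iters):
    # E step: infer soft genre assignments (probabilities in [0, 1]) for the
    # anime with unknown genres, given the current model parameters.
    ohe_genres[unknown_idx] = infer_genres(params, unknown_idx)

    # M step: refit the LFM-with-features objective, treating the inferred
    # probabilities exactly like observed one-hot genre encodings.
    params = fit_lfm(ohe_genres)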


In [7]:
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

%matplotlib inline
matplotlib.style.use('seaborn')

In [8]:
import pandas as pd

In [9]:
from animerec.data import get_data
users, anime = get_data()

In [10]:
from sklearn.model_selection import train_test_split
train, test = train_test_split(users, test_size = 0.1) # hold out 10% of the data as a test set
train, valid = train_test_split(train, test_size = 0.2) # hold out 20% of the remainder as validation (~72/18/10 train/valid/test overall)

In [11]:
from animerec.data import remove_users
train = remove_users(train, 10)

In [12]:
#define validation set
valid_users = valid['user_id']
valid_anime = valid['anime_id']
valid_ratings = valid['rating']

In [7]:
genres = anime.genre.apply(lambda x: str(x).split(","))
genres2 = genres.apply(pd.Series)
all_genres = []
for i in range(len(genres2.columns)):
    genres2[i] = genres2[i].str.strip()
    all_genres += map(lambda s: str(s).strip(), list(genres2[i].unique()))
all_genres = list(np.unique(all_genres))
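To see what the split-and-strip above produces, here is a quick illustration on a hypothetical genre string:

s = "Action, Comedy, Drama"
print([g.strip() for g in str(s).split(",")])  # ['Action', 'Comedy', 'Drama']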

In [8]:
genre_map = {x: i for i, x in enumerate(all_genres)}  # genre name -> column index
nGenres = len(all_genres)

In [9]:
indexed = anime.set_index('anime_id')
indexed.index.name = None
ind = indexed.ix[train.anime_id.get_values()]

In [10]:
train_genres = ind.genre.map(lambda x: [genre_map[j.strip()] for j in str(x).split(',')])

In [11]:
ohe_genres = np.zeros((len(train_genres), nGenres))
for i, row in enumerate(train_genres): ohe_genres[i, row] = 1
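The assignment `ohe_genres[i, row] = 1` relies on NumPy fancy indexing: passing a list of column indices sets all of those columns at once. A tiny illustration with made-up indices:

demo = np.zeros((1, 5))
demo[0, [1, 3]] = 1  # set columns 1 and 3 of row 0 in a single assignment
print(demo)          # [[0. 1. 0. 1. 0.]]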

In [27]:
v_ind = indexed.ix[valid.anime_id.get_values()]
valid_genres = v_ind.genre.map(lambda x: [genre_map[j.strip()] for j in str(x).split(',')])

v_ohe_genres = np.zeros((len(valid_genres), nGenres))
for i, row in enumerate(valid_genres): v_ohe_genres[i, row] = 1

In [32]:
test_ind = indexed.ix[test.anime_id.get_values()]
test_genres = test_ind.genre.map(lambda x: [genre_map[j.strip()] for j in str(x).split(',')])

test_ohe_genres = np.zeros((len(test_genres), nGenres))
for i, row in enumerate(test_genres): test_ohe_genres[i, row] = 1

In [12]:
#initialize some local variables
nUsers = len(train.user_id.unique())
nAnime = len(train.anime_id.unique())

# we'll need some data structures in order to vectorize computations
from collections import defaultdict
user_ids = train.user_id
item_ids = train.anime_id

user_index = defaultdict(lambda: -1) # maps a user_id to the index in the bias term.
item_index = defaultdict(lambda: -1) # maps an anime_id to the index in the bias term.

counter = 0
for user in user_ids:
    if user_index[user] == -1:
        user_index[user] = counter
        counter += 1 

counter = 0
for item in item_ids:
    if item_index[item] == -1:
        item_index[item] = counter
        counter += 1
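As an aside, the two loops above are just first-appearance deduplication. An equivalent, more compact construction would be the one below: pd.unique preserves order of first appearance, and defaultdict accepts an initial mapping.

# Equivalent construction, yielding the same index mappings as the loops above.
user_index = defaultdict(lambda: -1, {u: i for i, u in enumerate(pd.unique(user_ids))})
item_index = defaultdict(lambda: -1, {a: i for i, a in enumerate(pd.unique(item_ids))})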

In [13]:
#Terms needed for the latent factors.
k = 3  # hyper-parameter: number of latent factors

In [14]:
import tensorflow as tf

In [15]:
y = tf.cast(tf.constant(train['rating'].as_matrix(), shape=[len(train),1]), tf.float32)

In [20]:
def objective(alpha, Bi, Bu, Gi, Gu, Pi, y, lam): #Gi, Gu = gamma_i, gamma_u = latent factors for items, users 
    
    '''
    Like in the linear model, we need to construct the "full" matrix for each (user, item) pair. However, with the
    addition of the latent factor terms, it will waste memory to hold each variable in its own tensor.
    Instead, create one intermediary tensor to represent our prediction ("pred") and simply incrementally add to that
    each additional variable.
    '''
    pred = tf.gather(Bi, train.anime_id.map(lambda _id: item_index[_id]).as_matrix()) #Bi_full
    pred += tf.gather(Bu, train.user_id.map(lambda _id: user_index[_id]).as_matrix()) #Bu_full
    
    Gi_full = tf.gather(Gi, train.anime_id.map(lambda _id: item_index[_id]).as_matrix()) #latent factors of items
    Gu_full = tf.gather(Gu, train.user_id.map(lambda _id: user_index[_id]).as_matrix()) #latent factors of users

    Pi_full = tf.matmul(tf.constant(ohe_genres, dtype=tf.float32), Pi) 
    
    # einsum('ij,ji->i') computes the row-wise dot product: each rating's
    # interaction term (gamma_i + genre factors) . gamma_u
    pred += tf.expand_dims(tf.einsum('ij,ji->i', (Gi_full+Pi_full), tf.transpose(Gu_full)), 1)

    pred += tf.tile(alpha, (len(train), 1)) #alpha_full
    obj = tf.reduce_sum(abs(pred-y))


    # regularization
    obj += lam * tf.reduce_sum(Bi**2)
    obj += lam * tf.reduce_sum(Bu**2) 
    obj += lam * tf.reduce_sum(Gi**2) 
    obj += lam * tf.reduce_sum(Gu**2)
    
    return obj
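For reference, the objective the function above implements is, writing $G(i)$ for the set of genres of item $i$:

$$\sum_{(u,i)} \left| r_{ui} - \Big( \alpha + \beta_i + \beta_u + \big(\gamma_i + \textstyle\sum_{g \in G(i)} \pi_g\big) \cdot \gamma_u \Big) \right| + \lambda \Big( \textstyle\sum_i \beta_i^2 + \sum_u \beta_u^2 + \sum_i \lVert \gamma_i \rVert^2 + \sum_u \lVert \gamma_u \rVert^2 \Big)$$

Note that the genre factors $\pi$ are left unregularized in the code above.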

In [17]:
#initialize alpha, Bi, Bu, Gi, Gu 
alpha = tf.Variable(tf.constant([6.9], shape=[1, 1]))
Bi = tf.Variable(tf.constant([0.0]*nAnime, shape=[nAnime, 1]))
Bu = tf.Variable(tf.constant([0.0]*nUsers, shape=[nUsers, 1]))
Gi = tf.Variable(tf.random_normal([nAnime, k], stddev=0.35))
Gu = tf.Variable(tf.random_normal([nUsers, k], stddev=0.35))
Pi = tf.Variable(tf.random_normal([nGenres, k], stddev=0.35))

In [18]:
optimizer = tf.train.AdamOptimizer(0.01)

In [21]:
obj = objective(alpha, Bi, Bu, Gi, Gu, Pi, y, 1)

In [22]:
trainer = optimizer.minimize(obj)

In [29]:
sess = tf.Session()

sess.run(tf.global_variables_initializer())
tLoss = []
vLoss = []
prev = 10e10
for iteration in range(500):
    cvalues = sess.run([trainer, obj])
    print("objective = " + str(cvalues[1]))
    tLoss.append(cvalues[1])
    
    # every 5th iteration, evaluate MAE on the validation set and stop early
    # if it gets worse than the previous check
    if not iteration % 5:
        cAlpha, cBi, cBu, cGi, cGu, cPi, cLoss = sess.run([alpha, Bi, Bu, Gi, Gu, Pi, obj])
        indices = valid_users.map(lambda x: user_index[x])
        bu = indices.map(lambda x: 0.0 if x == -1 else float(cBu[x]))
        gu = indices.map(lambda x: np.zeros(k) if x == -1 else cGu[x])
        gu = np.vstack(gu.as_matrix()).astype(np.float)


        indices = valid_anime.map(lambda x: item_index[x])
        bi = indices.map(lambda x: 0.0 if x == -1 else float(cBi[x]))
        gi = indices.map(lambda x: np.zeros(k) if x == -1 else cGi[x])
        gi = np.vstack(gi.as_matrix()).astype(np.float)
        
        pi = np.dot(v_ohe_genres, cPi) 


        g = np.einsum('ij,ji->i', (gi+pi), np.transpose(gu)) 

        preds = bu + bi + g + float(cAlpha)
        MAE = 1.0/len(valid) * sum(abs(valid_ratings-preds))
        vLoss.append(MAE)
        if MAE > prev: break
        else: prev = MAE
    
cAlpha, cBi, cBu, cGi, cGu, cPi, cLoss = sess.run([alpha, Bi, Bu, Gi, Gu, Pi, obj]) # fetch Pi too, so cPi matches the final parameters used at test time
print("\nFinal train loss is ", cLoss)


objective = 4.72971e+06
objective = 4.65157e+06
objective = 4.57885e+06
objective = 4.51098e+06
objective = 4.44739e+06
[... 490 similar lines elided: the objective decreases monotonically ...]
objective = 2.41947e+06
objective = 2.41937e+06
objective = 2.41927e+06
objective = 2.41917e+06
objective = 2.41907e+06

Final train loss is  2.41897e+06

Let's plot the objective and see how it decreases.


In [30]:
fig, ax1 = plt.subplots()
plt.title('Latent factor model (with genre features) performance vs. iterations')
ax1.plot(tLoss, 'b-')
ax1.set_xlabel('Training Iterations')
ax1.set_ylabel('Train Loss')

ax2 = ax1.twinx()
ax2.plot(range(0, len(vLoss)*5, 5), vLoss, 'r.')
ax2.set_ylabel('Validation MAE')

fig.tight_layout()


We can see quite clearly that this model does not overfit, just like the linear model: L2 regularization looks to be more than enough. However, we can also see that validation MAE mostly flatlines after about 150 iterations, which suggests a natural point for early stopping (see the sketch below).
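As a minimal sketch of such a rule, reusing the vLoss list recorded every 5 iterations above (the patience value is a made-up choice):

# stop once validation MAE fails to improve for `patience` consecutive checks
best, stall, patience = float('inf'), 0, 5
for check, mae in enumerate(vLoss):
    if mae < best:
        best, stall = mae, 0
    else:
        stall += 1
        if stall >= patience:
            print('would stop around iteration', check * 5)
            break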

Let's test the model on our test data now.


In [33]:
test_users = test['user_id']
test_anime = test['anime_id']
test_ratings = test['rating']

In [34]:
indices = test_users.map(lambda x: user_index[x])
bu = indices.map(lambda x: 0.0 if x == -1 else float(cBu[x]))
gu = indices.map(lambda x: np.zeros(k) if x == -1 else cGu[x])
gu = np.vstack(gu.as_matrix()).astype(np.float)

indices = test_anime.map(lambda x: item_index[x])
bi = indices.map(lambda x: 0.0 if x == -1 else float(cBi[x]))
gi = indices.map(lambda x: np.zeros(k) if x == -1 else cGi[x])
gi = np.vstack(gi.as_matrix()).astype(np.float)
pi = np.dot(test_ohe_genres, cPi) 


g = np.einsum('ij,ji->i', (gi+pi), np.transpose(gu)) 

preds = bu + bi + g + float(cAlpha)
MAE = 1.0/len(test) * sum(abs(test_ratings-preds))
print ('MAE on test set is: ', float(MAE))


MAE on test set is:  0.8685068394580536

In [ ]:
sess.close()

So, an improvement over the basic latent factor model. What's cool about this model is that, in addition to improving rating predictions, it gives us a way to infer plausible genres for the shows that are missing them.
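As a hedged illustration of that readout (probs is hypothetical here, standing in for the genre probabilities the E step would infer for one unknown-genre anime):

probs = np.random.rand(nGenres)      # stand-in for real E-step inferences
top = np.argsort(probs)[::-1][:3]    # indices of the three most likely genres
print([all_genres[j] for j in top])  # map indices back to genre names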