In [ ]:
%matplotlib inline
import os, sys, time, gzip
import pickle as pkl
import numpy as np
from scipy.sparse import lil_matrix, csr_matrix, issparse
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
In [ ]:
from tools import calc_metrics, diversity, pairwise_distance_hamming, softmax
In [ ]:
np.seterr(all='raise')
In [ ]:
TOPs = [5, 10, 20, 30, 50, 100, 200, 300, 500, 700, 1000]
In [ ]:
datasets = ['aotm2011', '30music']
In [ ]:
dix = 1
dataset_name = datasets[dix]
dataset_name
In [ ]:
data_dir = 'data/%s/coldstart/setting1' % dataset_name
X_trndev = pkl.load(gzip.open(os.path.join(data_dir, 'X_trndev.pkl.gz'), 'rb'))
Y_trndev = pkl.load(gzip.open(os.path.join(data_dir, 'Y_trndev.pkl.gz'), 'rb'))
X_test = pkl.load(gzip.open(os.path.join(data_dir, 'X_test.pkl.gz'), 'rb'))
Y_test = pkl.load(gzip.open(os.path.join(data_dir, 'Y_test.pkl.gz'), 'rb'))
In [ ]:
songs1 = pkl.load(gzip.open(os.path.join(data_dir, 'songs_train_dev_test_s1.pkl.gz'), 'rb'))
train_songs = songs1['train_song_set']
dev_songs = songs1['dev_song_set']
test_songs = songs1['test_song_set']
In [ ]:
song2index_trndev = {sid: ix for ix, (sid, _) in enumerate(train_songs + dev_songs)}
song2index_test = {sid: ix for ix, (sid, _) in enumerate(test_songs)}
index2song_test = {ix: sid for ix, (sid, _) in enumerate(test_songs)}
In [ ]:
_song2artist = pkl.load(gzip.open('data/msd/song2artist.pkl.gz', 'rb'))
song2artist = {sid: _song2artist[sid] for sid, _ in train_songs + dev_songs + test_songs if sid in _song2artist}
In [ ]:
all_playlists = pkl.load(gzip.open(os.path.join(data_dir, 'playlists_s1.pkl.gz'), 'rb'))
In [ ]:
artist2pop = dict()
test_songset = set(sid for sid, _ in test_songs)  # song IDs only: test_songs holds (songID, ...) tuples
for pl, _ in all_playlists:
    for sid in [sid for sid in pl if sid not in test_songset]:
        if sid in song2artist:
            aid = song2artist[sid]
            try:
                artist2pop[aid] += 1
            except KeyError:
                artist2pop[aid] = 1
In [ ]:
song2genre = pkl.load(gzip.open('data/msd/song2genre.pkl.gz', 'rb'))
In [ ]:
cliques_all = pkl.load(gzip.open(os.path.join(data_dir, 'cliques_trndev.pkl.gz'), 'rb'))
In [ ]:
U = len(cliques_all)
pl2u = np.zeros(Y_test.shape[1], dtype=np.int32)
for u in range(U):
    clq = cliques_all[u]
    pl2u[clq] = u
In [ ]:
song2pop = pkl.load(gzip.open(os.path.join(data_dir, 'song2pop.pkl.gz'), 'rb'))
In [ ]:
Y_test.shape
In [ ]:
X_trndev.shape
In [ ]:
Y_trndev.shape
Let $S \in \mathbb{R}^{M \times D}$ and $P \in \mathbb{R}^{N \times D}$ be the latent factors of the $M$ songs and $N$ playlists, respectively, and let $Y \in \mathbb{R}^{M \times N}$ be the song-playlist matrix.
The optimisation objective:
$
\begin{aligned}
J = \sum_{m=1}^M \sum_{n=1}^N \left( y_{m,n} - \mathbf{s}_m^\top \mathbf{p}_n \right)^2
+ C \left( \sum_{m=1}^M \mathbf{s}_m^\top \mathbf{s}_m + \sum_{n=1}^N \mathbf{p}_n^\top \mathbf{p}_n \right)
\end{aligned}
$
The objective is optimised with the alternating least squares (ALS) method.
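Setting the gradient of $J$ to zero with $S$ (respectively $P$) held fixed yields the closed-form updates implemented in the cell below:
$
\begin{aligned}
\mathbf{p}_n = \left( S^\top S + C I \right)^{-1} S^\top \mathbf{y}_{:n},
\qquad
\mathbf{s}_m = \left( P^\top P + C I \right)^{-1} P^\top \mathbf{y}_{m:}^\top
\end{aligned}
$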
In [ ]:
M, N = Y_trndev.shape
D = 80
C = 1
n_sweeps = 200
np.random.seed(0)
S = np.random.rand(M, D)
P = np.random.rand(N, D)
# alternating least squares
for sweep in range(n_sweeps):
    # fix S, optimise P
    SS = np.dot(S.T, S)  # D by D
    np.fill_diagonal(SS, C + SS.diagonal())
    P_new = np.dot(Y_trndev.transpose().dot(S), np.linalg.inv(SS).T)  # N by D
    pdiff = (P_new - P).ravel()
    P = P_new

    # fix P, optimise S
    PP = np.dot(P.T, P)  # D by D
    np.fill_diagonal(PP, C + PP.diagonal())
    S_new = np.dot(Y_trndev.dot(P), np.linalg.inv(PP).T)  # M by D
    sdiff = (S_new - S).ravel()
    S = S_new

    print('P diff: {:8.6f}, S diff: {:8.6f}'.format(np.sqrt(pdiff.dot(pdiff)), np.sqrt(sdiff.dot(sdiff))))
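For reference, the same $P$-update can be written with a linear solve instead of an explicit inverse; since $S^\top S + C I$ is symmetric positive definite, `np.linalg.solve` is the numerically preferred way to apply its inverse. This is a minimal sketch of the equivalent formulation, not a change to the sweep above:
In [ ]:
# Equivalent P-update via np.linalg.solve (SS is symmetric positive definite).
SS = np.dot(S.T, S)
np.fill_diagonal(SS, C + SS.diagonal())
P_alt = np.linalg.solve(SS, Y_trndev.transpose().dot(S).T).T  # N by D
print(np.allclose(P_alt, np.dot(Y_trndev.transpose().dot(S), np.linalg.inv(SS))))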
Sanity check: RMSE over the observed (non-zero) entries of $Y$.
In [ ]:
Y_trndev_coo = Y_trndev.tocoo()
In [ ]:
loss = 0.
for row, col in tqdm(zip(Y_trndev_coo.row, Y_trndev_coo.col)):
    diff = S[row, :].dot(P[col, :]) - 1
    loss += diff * diff
loss /= Y_trndev_coo.nnz
print('RMSE:', np.sqrt(loss))
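In this cold-start setting the test songs never appear in $Y$, so they have no learned latent factor; the evaluation loop below therefore needs some map from song features to the song latent space before a (song, playlist) pair can be scored, and that map is left unspecified here. The cell below is a minimal sketch under an assumed ridge-regression map `W` (with assumed regularisation constant `C_map`) fitted from `X_trndev` to `S`; it also assumes the playlists indexed by `Y_test` are the same as, and indexed consistently with, those in `Y_trndev`, so a test song with features $\mathbf{x}$ could be scored against playlist $n$ as $(\mathbf{x}^\top W)\,\mathbf{p}_n$.
In [ ]:
# Hedged sketch (an assumption, not necessarily this notebook's intended model):
# a ridge-regression map from song features to latent factors,
#     W = argmin_W ||X_trndev W - S||_F^2 + C_map ||W||_F^2,
# used to embed the cold-start test songs.
C_map = 1.  # assumed regularisation constant
X_trn = X_trndev.toarray() if issparse(X_trndev) else X_trndev
X_tst = X_test.toarray() if issparse(X_test) else X_test
A = X_trn.T @ X_trn + C_map * np.eye(X_trn.shape[1])
W = np.linalg.solve(A, X_trn.T @ S)   # (#features) by D
S_test = X_tst @ W                    # latent factors of the test songs, (#test songs) by D
# Inside the loop below, the score of (test song ix, playlist j) would then be
#     y_pred[ix] = S_test[ix, :].dot(P[j, :])
# or, vectorised over all test songs: y_pred = S_test.dot(P[j, :])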
In [ ]:
rps = []
hitrates = {top: [] for top in TOPs}
aucs = []
spreads = []
novelties = {top: dict() for top in TOPs}
artist_diversities = {top: [] for top in TOPs}
genre_diversities = {top: [] for top in TOPs}
np.random.seed(0)
npos = Y_test.sum(axis=0).A.reshape(-1)
assert Y_test.shape[0] == len(test_songs)
for j in range(Y_test.shape[1]):
    if (j+1) % 100 == 0:
        sys.stdout.write('\r%d / %d' % (j+1, Y_test.shape[1]))
        sys.stdout.flush()
    if npos[j] < 1:
        continue
    y_true = Y_test[:, j].A.reshape(-1)

    y_pred = np.zeros(len(test_songs))
    for ix in range(len(test_songs)):
        sid = index2song_test[ix]
        # map song feature to song latent factor
        # score (song, playlist) pair by the dot product of their latent factors

    rp, hr_dict, auc = calc_metrics(y_true, y_pred, tops=TOPs)
    rps.append(rp)
    for top in TOPs:
        hitrates[top].append(hr_dict[top])
    aucs.append(auc)

    # spread
    y_pred_prob = softmax(y_pred)
    spreads.append(-np.dot(y_pred_prob, np.log(y_pred_prob)))

    # novelty
    sortix = np.argsort(-y_pred)
    u = pl2u[j]
    for top in TOPs:
        nov = np.mean([-np.log2(song2pop[index2song_test[ix]]) for ix in sortix[:top]])
        try:
            novelties[top][u].append(nov)
        except KeyError:
            novelties[top][u] = [nov]

    # artist/genre diversity
    for top in TOPs:
        artist_vec = np.array([song2artist[index2song_test[ix]] for ix in sortix[:top]])
        genre_vec = np.array([song2genre[index2song_test[ix]] if index2song_test[ix] in song2genre
                              else str(np.random.rand()) for ix in sortix[:top]])
        artist_diversities[top].append(diversity(artist_vec))
        genre_diversities[top].append(diversity(genre_vec))

print('\n%d / %d' % (len(rps), Y_test.shape[1]))
In [ ]:
perf = {dataset_name: {'Test': {'R-Precision': np.mean(rps),
'Hit-Rate': {top: np.mean(hitrates[top]) for top in TOPs},
'AUC': np.mean(aucs),
'Spread': np.mean(spreads),
'Novelty': {t: np.mean([np.mean(novelties[t][u]) for u in novelties[t]])
for t in TOPs},
'Artist-Diversity': {top: np.mean(artist_diversities[top]) for top in TOPs},
'Genre-Diversity': {top: np.mean(genre_diversities[top]) for top in TOPs}},
'Test_All': {'R-Precision': rps,
'Hit-Rate': {top: hitrates[top] for top in TOPs},
'AUC': aucs,
'Spread': spreads,
'Novelty': novelties,
'Artist-Diversity': artist_diversities,
'Genre-Diversity': genre_diversities}}}
perf[dataset_name]['Test']
In [ ]:
fperf = os.path.join(data_dir, 'perf-mfcnn.pkl')
print(fperf)
pkl.dump(perf, open(fperf, 'wb'))
pkl.load(open(fperf, 'rb'))[dataset_name]['Test']