In [1]:
import numpy as np

from scipy import io
from sklearn.metrics import roc_auc_score, average_precision_score

import pmf

In [2]:
train_tracks = list()
with open('train_tracks.txt', 'rb') as f:
    for line in f:
        train_tracks.append(line.strip())
        
test_tracks = list()
with open('test_tracks.txt', 'rb') as f:
    for line in f:
        test_tracks.append(line.strip())
        
tags = list()
with open('voc.txt', 'rb') as f:
    for line in f:
        tags.append(line.strip())

In [3]:
# compute evaluation metrics
def construct_pred_mask(tags_predicted, predictat):
    n_samples, n_tags = tags_predicted.shape
    rankings = np.argsort(-tags_predicted, axis=1)[:, :predictat]
    tags_predicted_binary = np.zeros_like(tags_predicted, dtype=bool)
    for i in xrange(n_samples):
        tags_predicted_binary[i, rankings[i]] = 1
    return tags_predicted_binary

def per_tag_prec_recall(tags_predicted_binary, tags_true_binary):
    mask = np.logical_and(tags_predicted_binary, tags_true_binary)
    prec = mask.sum(axis=0) / (tags_predicted_binary.sum(axis=0) + np.spacing(1))
    tags_true_count = tags_true_binary.sum(axis=0).astype(float)
    idx = (tags_true_count > 0)
    recall = mask.sum(axis=0)[idx] / tags_true_count[idx]
    return prec, recall


def aroc_ap(tags_true_binary, tags_predicted):
    n_tags = tags_true_binary.shape[1]
    
    auc = list()
    aprec = list()
    for i in xrange(n_tags):
        if np.sum(tags_true_binary[:, i]) != 0:
            auc.append(roc_auc_score(tags_true_binary[:, i], tags_predicted[:, i]))
            aprec.append(average_precision_score(tags_true_binary[:, i], tags_predicted[:, i]))
    return auc, aprec


def print_out_metrics(tags_true_binary, tags_predicted, predictat):
    tags_predicted_binary = construct_pred_mask(tags_predicted, predictat)
    prec, recall = per_tag_prec_recall(tags_predicted_binary, tags_true_binary)
    mprec, mrecall = np.mean(prec), np.mean(recall)
    
    print 'Precision = %.3f (%.3f)' % (mprec, np.std(prec) / sqrt(prec.size))
    print 'Recall = %.3f (%.3f)' % (mrecall, np.std(recall) / sqrt(recall.size))
    print 'F-score = %.3f' % (2 * mprec * mrecall / (mprec + mrecall))

    auc, aprec = aroc_ap(tags_true_binary, tags_predicted)
    print 'AROC = %.3f (%.3f)' % (np.mean(auc), np.std(auc) / sqrt(len(auc)))
    print 'AP = %.3f (%.3f)' % (np.mean(aprec), np.std(aprec) / sqrt(len(aprec)))

In [4]:
# codebook size (for in memoery, should not be too large)
K = 512

In [5]:
# load the pre-saved data
data_mat = io.loadmat('data_K%d.mat' % K)
X, X_test, y_test = data_mat['X'], data_mat['X_test'], data_mat['y_test']

In [6]:
tmp = X[:, K:]
tmp[tmp > 0] = 1
X[:, K:] = tmp

In [7]:
hist(np.sum( (y_test > 0), axis=1), bins=50)
pass



In [8]:
D = K + len(tags)

In [9]:
# pick a random song and take a look
bar(np.arange(D), X[0])


Out[9]:
<Container object of 1073 artists>

Batch inference on 10K subset


In [10]:
n_components = 100
coder = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)

In [11]:
coder.fit(X)


	After ITERATION: 39	Objective: 10746172.92	Old objective: 10741000.95	Improvement: 0.00048
pmf.py:164: RuntimeWarning: invalid value encountered in double_scalars
  improvement = (bound - old_bd) / abs(old_bd)
Out[11]:
PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100,
     tol=0.0005, verbose=True)

In [12]:
tagger = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)

In [13]:
tagger.set_components(coder.gamma_b[:, :K], coder.rho_b)


Out[13]:
PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100,
     tol=0.0005, verbose=True)

In [14]:
Et = tagger.transform(X_test)


	After ITERATION: 18	Objective: 3308237.31	Old objective: 3306803.62	Improvement: 0.00043

In [15]:
Et /= Et.sum(axis=1, keepdims=True)

tags_predicted = Et.dot(coder.Eb[:, K:])
print tags_predicted.min(), tags_predicted.max()

div_factor = 3
tags_predicted = tags_predicted - div_factor * np.mean(tags_predicted, axis=0)


0.00094841825642 1.35247354242

In [16]:
predictat = 20
tags_true_binary = (y_test > 0)

print_out_metrics(tags_true_binary, tags_predicted, predictat)


Precision = 0.111 (0.007)
Recall = 0.106 (0.006)
F-score = 0.108
AROC = 0.640 (0.005)
AP = 0.097 (0.005)

Stochastic inference on 10K subset


In [17]:
n_components = 100
online_coder = pmf.OnlinePoissonMF(n_components=n_components, batch_size=500, n_pass=1, 
                                   random_state=98765, verbose=True)

In [18]:
online_coder.fit(X, est_total=len(train_tracks))


Iteration 0: passing through the data...
	Minibatch 1:
	After ITERATION: 99	Objective: -478725.05	Old objective: -478998.31	Improvement: 0.00057
	Minibatch 2:
	After ITERATION: 69	Objective: 83380.31	Old objective: 83338.70	Improvement: 0.00050
	Minibatch 3:
	After ITERATION: 41	Objective: 217748.29	Old objective: 217640.87	Improvement: 0.00049
	Minibatch 4:
	After ITERATION: 33	Objective: 316526.52	Old objective: 316374.09	Improvement: 0.00048
	Minibatch 5:
	After ITERATION: 30	Objective: 322068.90	Old objective: 321908.73	Improvement: 0.00050
	Minibatch 6:
	After ITERATION: 30	Objective: 311596.59	Old objective: 311442.51	Improvement: 0.00049
	Minibatch 7:
	After ITERATION: 28	Objective: 335036.23	Old objective: 334869.13	Improvement: 0.00050
	Minibatch 8:
	After ITERATION: 27	Objective: 335628.81	Old objective: 335470.08	Improvement: 0.00047
	Minibatch 9:
	After ITERATION: 26	Objective: 345758.47	Old objective: 345589.10	Improvement: 0.00049
	Minibatch 10:
	After ITERATION: 26	Objective: 336569.17	Old objective: 336411.84	Improvement: 0.00047
	Minibatch 11:
	After ITERATION: 24	Objective: 401768.07	Old objective: 401568.16	Improvement: 0.00050
	Minibatch 12:
	After ITERATION: 24	Objective: 418374.06	Old objective: 418176.19	Improvement: 0.00047
	Minibatch 13:
	After ITERATION: 24	Objective: 391494.32	Old objective: 391309.49	Improvement: 0.00047
	Minibatch 14:
	After ITERATION: 24	Objective: 398904.84	Old objective: 398724.56	Improvement: 0.00045
	Minibatch 15:
	After ITERATION: 24	Objective: 389183.27	Old objective: 389004.84	Improvement: 0.00046
	Minibatch 16:
	After ITERATION: 22	Objective: 446196.69	Old objective: 445983.59	Improvement: 0.00048
	Minibatch 17:
	After ITERATION: 23	Objective: 420530.37	Old objective: 420340.72	Improvement: 0.00045
	Minibatch 18:
	After ITERATION: 21	Objective: 475206.58	Old objective: 474976.67	Improvement: 0.00048
	Minibatch 19:
	After ITERATION: 21	Objective: 449503.35	Old objective: 449283.21	Improvement: 0.00049
	Minibatch 20:
	After ITERATION: 21	Objective: 473768.83	Old objective: 473546.76	Improvement: 0.00047
Out[18]:
OnlinePoissonMF(batch_size=500, max_iter=100, n_components=100, n_pass=1,
        random_state=98765, shuffle=True, smoothness=100, tol=0.0005,
        verbose=True)

In [19]:
plot(online_coder.bound)
pass



In [20]:
tagger = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)

In [21]:
tagger.set_components(online_coder.gamma_b[:, :K], online_coder.rho_b[:, :K])


Out[21]:
PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100,
     tol=0.0005, verbose=True)

In [22]:
Et = tagger.transform(X_test)


	After ITERATION: 20	Objective: 3068245.32	Old objective: 3066785.25	Improvement: 0.00048

In [23]:
Et /= Et.sum(axis=1, keepdims=True)

tags_predicted = Et.dot(online_coder.Eb[:, K:])
n_samples, n_tags = tags_predicted.shape

print tags_predicted.min(), tags_predicted.max()

div_factor = 3
tags_predicted = tags_predicted - div_factor * np.mean(tags_predicted, axis=0)


5.46251938084e-05 1.09877408196

In [24]:
predictat = 20
tags_true_binary = (y_test > 0)

print_out_metrics(tags_true_binary, tags_predicted, predictat)


Precision = 0.112 (0.007)
Recall = 0.128 (0.007)
F-score = 0.120
AROC = 0.684 (0.005)
AP = 0.110 (0.006)

Stochastic inference on the full set


In [25]:
data_mat = io.loadmat('X_train_K%d.mat' % K)
tag_mat = io.loadmat('y_train.mat')

In [26]:
X = np.hstack((data_mat['X'], tag_mat['y_train']))

In [27]:
n_components = 100
batch_size = 1000
online_coder_full = pmf.OnlinePoissonMF(n_components=n_components, batch_size=batch_size, n_pass=1, 
                                        random_state=98765, verbose=True)

In [28]:
online_coder_full.fit(X)


Iteration 0: passing through the data...
	Minibatch 1:
	After ITERATION: 99	Objective: -838968.48	Old objective: -839520.23	Improvement: 0.00066
	Minibatch 2:
	After ITERATION: 55	Objective: 330451.59	Old objective: 330293.21	Improvement: 0.00048
	Minibatch 3:
	After ITERATION: 38	Objective: 553125.94	Old objective: 552867.14	Improvement: 0.00047
	Minibatch 4:
	After ITERATION: 33	Objective: 635871.16	Old objective: 635556.16	Improvement: 0.00050
	Minibatch 5:
	After ITERATION: 30	Objective: 695878.24	Old objective: 695546.55	Improvement: 0.00048
	Minibatch 6:
	After ITERATION: 28	Objective: 768466.09	Old objective: 768087.58	Improvement: 0.00049
	Minibatch 7:
	After ITERATION: 28	Objective: 739469.91	Old objective: 739118.06	Improvement: 0.00048
	Minibatch 8:
	After ITERATION: 26	Objective: 836343.78	Old objective: 835954.64	Improvement: 0.00047
	Minibatch 9:
	After ITERATION: 26	Objective: 852619.10	Old objective: 852227.48	Improvement: 0.00046
	Minibatch 10:
	After ITERATION: 25	Objective: 819482.01	Old objective: 819082.19	Improvement: 0.00049
	Minibatch 11:
	After ITERATION: 24	Objective: 845208.45	Old objective: 844814.10	Improvement: 0.00047
	Minibatch 12:
	After ITERATION: 24	Objective: 845044.81	Old objective: 844642.22	Improvement: 0.00048
	Minibatch 13:
	After ITERATION: 24	Objective: 816059.30	Old objective: 815666.91	Improvement: 0.00048
	Minibatch 14:
	After ITERATION: 24	Objective: 794065.94	Old objective: 793680.55	Improvement: 0.00049
	Minibatch 15:
	After ITERATION: 22	Objective: 974278.44	Old objective: 973796.11	Improvement: 0.00050
	Minibatch 16:
	After ITERATION: 23	Objective: 896549.23	Old objective: 896143.46	Improvement: 0.00045
	Minibatch 17:
	After ITERATION: 22	Objective: 863382.90	Old objective: 862961.57	Improvement: 0.00049
	Minibatch 18:
	After ITERATION: 23	Objective: 882706.00	Old objective: 882306.14	Improvement: 0.00045
	Minibatch 19:
	After ITERATION: 22	Objective: 891975.51	Old objective: 891550.33	Improvement: 0.00048
	Minibatch 20:
	After ITERATION: 21	Objective: 924968.90	Old objective: 924517.78	Improvement: 0.00049
	Minibatch 21:
	After ITERATION: 21	Objective: 948423.94	Old objective: 947972.97	Improvement: 0.00048
	Minibatch 22:
	After ITERATION: 21	Objective: 855681.70	Old objective: 855254.13	Improvement: 0.00050
	Minibatch 23:
	After ITERATION: 20	Objective: 979569.87	Old objective: 979086.64	Improvement: 0.00049
	Minibatch 24:
	After ITERATION: 21	Objective: 959566.25	Old objective: 959114.87	Improvement: 0.00047
	Minibatch 25:
	After ITERATION: 20	Objective: 966727.63	Old objective: 966257.28	Improvement: 0.00049
	Minibatch 26:
	After ITERATION: 21	Objective: 924176.66	Old objective: 923751.42	Improvement: 0.00046
	Minibatch 27:
	After ITERATION: 20	Objective: 955808.82	Old objective: 955350.94	Improvement: 0.00048
	Minibatch 28:
	After ITERATION: 20	Objective: 933417.55	Old objective: 932984.50	Improvement: 0.00046
	Minibatch 29:
	After ITERATION: 20	Objective: 934965.92	Old objective: 934510.41	Improvement: 0.00049
	Minibatch 30:
	After ITERATION: 20	Objective: 923400.92	Old objective: 922950.94	Improvement: 0.00049
	Minibatch 31:
	After ITERATION: 19	Objective: 1011637.22	Old objective: 1011153.42	Improvement: 0.00048
	Minibatch 32:
	After ITERATION: 20	Objective: 919927.35	Old objective: 919490.85	Improvement: 0.00047
	Minibatch 33:
	After ITERATION: 19	Objective: 983848.50	Old objective: 983359.24	Improvement: 0.00050
	Minibatch 34:
	After ITERATION: 19	Objective: 1011851.79	Old objective: 1011374.01	Improvement: 0.00047
	Minibatch 35:
	After ITERATION: 19	Objective: 943447.83	Old objective: 942977.35	Improvement: 0.00050
	Minibatch 36:
	After ITERATION: 19	Objective: 950197.67	Old objective: 949724.08	Improvement: 0.00050
	Minibatch 37:
	After ITERATION: 19	Objective: 992837.84	Old objective: 992395.29	Improvement: 0.00045
	Minibatch 38:
	After ITERATION: 19	Objective: 981298.24	Old objective: 980827.19	Improvement: 0.00048
	Minibatch 39:
	After ITERATION: 19	Objective: 994155.85	Old objective: 993681.70	Improvement: 0.00048
	Minibatch 40:
	After ITERATION: 19	Objective: 944250.43	Old objective: 943814.25	Improvement: 0.00046
	Minibatch 41:
	After ITERATION: 19	Objective: 952229.46	Old objective: 951816.84	Improvement: 0.00043
	Minibatch 42:
	After ITERATION: 18	Objective: 991395.34	Old objective: 990905.65	Improvement: 0.00049
	Minibatch 43:
	After ITERATION: 19	Objective: 979112.73	Old objective: 978678.59	Improvement: 0.00044
	Minibatch 44:
	After ITERATION: 18	Objective: 1021684.05	Old objective: 1021176.08	Improvement: 0.00050
	Minibatch 45:
	After ITERATION: 19	Objective: 964037.28	Old objective: 963590.83	Improvement: 0.00046
	Minibatch 46:
	After ITERATION: 19	Objective: 996080.71	Old objective: 995649.74	Improvement: 0.00043
	Minibatch 47:
	After ITERATION: 18	Objective: 1007201.75	Old objective: 1006727.54	Improvement: 0.00047
	Minibatch 48:
	After ITERATION: 19	Objective: 963821.48	Old objective: 963378.24	Improvement: 0.00046
	Minibatch 49:
	After ITERATION: 18	Objective: 977100.85	Old objective: 976616.29	Improvement: 0.00050
	Minibatch 50:
	After ITERATION: 18	Objective: 1037832.34	Old objective: 1037336.13	Improvement: 0.00048
	Minibatch 51:
	After ITERATION: 18	Objective: 1019284.29	Old objective: 1018816.05	Improvement: 0.00046
	Minibatch 52:
	After ITERATION: 19	Objective: 957454.85	Old objective: 957031.01	Improvement: 0.00044
	Minibatch 53:
	After ITERATION: 19	Objective: 932207.71	Old objective: 931787.92	Improvement: 0.00045
	Minibatch 54:
	After ITERATION: 18	Objective: 1043985.60	Old objective: 1043509.42	Improvement: 0.00046
	Minibatch 55:
	After ITERATION: 17	Objective: 1054681.95	Old objective: 1054169.84	Improvement: 0.00049
	Minibatch 56:
	After ITERATION: 18	Objective: 947185.90	Old objective: 946717.33	Improvement: 0.00049
	Minibatch 57:
	After ITERATION: 18	Objective: 962078.55	Old objective: 961625.79	Improvement: 0.00047
	Minibatch 58:
	After ITERATION: 18	Objective: 954449.56	Old objective: 953988.60	Improvement: 0.00048
	Minibatch 59:
	After ITERATION: 18	Objective: 1033015.34	Old objective: 1032560.56	Improvement: 0.00044
	Minibatch 60:
	After ITERATION: 18	Objective: 950030.86	Old objective: 949565.39	Improvement: 0.00049
	Minibatch 61:
	After ITERATION: 18	Objective: 1006097.61	Old objective: 1005648.49	Improvement: 0.00045
	Minibatch 62:
	After ITERATION: 18	Objective: 1010524.60	Old objective: 1010063.96	Improvement: 0.00046
	Minibatch 63:
	After ITERATION: 18	Objective: 960181.94	Old objective: 959740.15	Improvement: 0.00046
	Minibatch 64:
	After ITERATION: 18	Objective: 1005279.89	Old objective: 1004814.24	Improvement: 0.00046
	Minibatch 65:
	After ITERATION: 18	Objective: 1006458.92	Old objective: 1006024.37	Improvement: 0.00043
	Minibatch 66:
	After ITERATION: 18	Objective: 1025903.07	Old objective: 1025452.39	Improvement: 0.00044
	Minibatch 67:
	After ITERATION: 18	Objective: 984433.06	Old objective: 983998.63	Improvement: 0.00044
	Minibatch 68:
	After ITERATION: 17	Objective: 1024722.72	Old objective: 1024232.46	Improvement: 0.00048
	Minibatch 69:
	After ITERATION: 17	Objective: 1049168.83	Old objective: 1048676.28	Improvement: 0.00047
	Minibatch 70:
	After ITERATION: 18	Objective: 939387.34	Old objective: 938955.50	Improvement: 0.00046
	Minibatch 71:
	After ITERATION: 18	Objective: 1017890.59	Old objective: 1017460.27	Improvement: 0.00042
	Minibatch 72:
	After ITERATION: 17	Objective: 1083556.84	Old objective: 1083076.20	Improvement: 0.00044
	Minibatch 73:
	After ITERATION: 17	Objective: 1003922.28	Old objective: 1003433.21	Improvement: 0.00049
	Minibatch 74:
	After ITERATION: 17	Objective: 1043622.76	Old objective: 1043121.90	Improvement: 0.00048
	Minibatch 75:
	After ITERATION: 18	Objective: 974777.54	Old objective: 974338.72	Improvement: 0.00045
	Minibatch 76:
	After ITERATION: 17	Objective: 1018929.83	Old objective: 1018472.21	Improvement: 0.00045
	Minibatch 77:
	After ITERATION: 17	Objective: 977949.65	Old objective: 977465.20	Improvement: 0.00050
	Minibatch 78:
	After ITERATION: 17	Objective: 1066141.66	Old objective: 1065660.94	Improvement: 0.00045
	Minibatch 79:
	After ITERATION: 18	Objective: 969007.97	Old objective: 968573.76	Improvement: 0.00045
	Minibatch 80:
	After ITERATION: 18	Objective: 990091.70	Old objective: 989666.39	Improvement: 0.00043
	Minibatch 81:
	After ITERATION: 17	Objective: 1012554.69	Old objective: 1012060.69	Improvement: 0.00049
	Minibatch 82:
	After ITERATION: 17	Objective: 1075548.13	Old objective: 1075070.57	Improvement: 0.00044
	Minibatch 83:
	After ITERATION: 17	Objective: 1013357.05	Old objective: 1012917.77	Improvement: 0.00043
	Minibatch 84:
	After ITERATION: 17	Objective: 951811.85	Old objective: 951341.39	Improvement: 0.00049
	Minibatch 85:
	After ITERATION: 17	Objective: 1039373.30	Old objective: 1038916.34	Improvement: 0.00044
	Minibatch 86:
	After ITERATION: 17	Objective: 1041026.72	Old objective: 1040543.29	Improvement: 0.00046
	Minibatch 87:
	After ITERATION: 17	Objective: 981806.83	Old objective: 981337.39	Improvement: 0.00048
	Minibatch 88:
	After ITERATION: 17	Objective: 984409.75	Old objective: 983922.09	Improvement: 0.00050
	Minibatch 89:
	After ITERATION: 17	Objective: 974591.89	Old objective: 974136.58	Improvement: 0.00047
	Minibatch 90:
	After ITERATION: 17	Objective: 1042582.16	Old objective: 1042111.95	Improvement: 0.00045
	Minibatch 91:
	After ITERATION: 17	Objective: 1072026.12	Old objective: 1071567.77	Improvement: 0.00043
	Minibatch 92:
	After ITERATION: 16	Objective: 1040577.82	Old objective: 1040063.57	Improvement: 0.00049
	Minibatch 93:
	After ITERATION: 17	Objective: 974010.32	Old objective: 973526.75	Improvement: 0.00050
	Minibatch 94:
	After ITERATION: 17	Objective: 1070298.98	Old objective: 1069809.17	Improvement: 0.00046
	Minibatch 95:
	After ITERATION: 17	Objective: 919223.29	Old objective: 918769.84	Improvement: 0.00049
	Minibatch 96:
	After ITERATION: 17	Objective: 1029070.82	Old objective: 1028594.58	Improvement: 0.00046
	Minibatch 97:
	After ITERATION: 17	Objective: 1017395.00	Old objective: 1016947.57	Improvement: 0.00044
	Minibatch 98:
	After ITERATION: 17	Objective: 1017731.10	Old objective: 1017278.11	Improvement: 0.00045
	Minibatch 99:
	After ITERATION: 17	Objective: 972302.05	Old objective: 971851.16	Improvement: 0.00046
	Minibatch 100:
	After ITERATION: 17	Objective: 1032231.04	Old objective: 1031776.26	Improvement: 0.00044
	Minibatch 101:
	After ITERATION: 17	Objective: 992875.16	Old objective: 992428.54	Improvement: 0.00045
	Minibatch 102:
	After ITERATION: 16	Objective: 1071424.25	Old objective: 1070903.82	Improvement: 0.00049
	Minibatch 103:
	After ITERATION: 16	Objective: 1068643.98	Old objective: 1068111.96	Improvement: 0.00050
	Minibatch 104:
	After ITERATION: 16	Objective: 1076202.58	Old objective: 1075678.17	Improvement: 0.00049
	Minibatch 105:
	After ITERATION: 17	Objective: 1052727.79	Old objective: 1052275.29	Improvement: 0.00043
	Minibatch 106:
	After ITERATION: 16	Objective: 1077327.28	Old objective: 1076794.34	Improvement: 0.00049
	Minibatch 107:
	After ITERATION: 17	Objective: 1035360.77	Old objective: 1034917.38	Improvement: 0.00043
	Minibatch 108:
	After ITERATION: 16	Objective: 1051049.96	Old objective: 1050535.16	Improvement: 0.00049
	Minibatch 109:
	After ITERATION: 16	Objective: 1089714.15	Old objective: 1089189.53	Improvement: 0.00048
	Minibatch 110:
	After ITERATION: 16	Objective: 1056248.04	Old objective: 1055756.68	Improvement: 0.00047
	Minibatch 111:
	After ITERATION: 16	Objective: 1120971.59	Old objective: 1120436.46	Improvement: 0.00048
	Minibatch 112:
	After ITERATION: 16	Objective: 1063079.05	Old objective: 1062564.31	Improvement: 0.00048
	Minibatch 113:
	After ITERATION: 17	Objective: 988163.42	Old objective: 987743.95	Improvement: 0.00042
	Minibatch 114:
	After ITERATION: 17	Objective: 1013004.93	Old objective: 1012561.58	Improvement: 0.00044
	Minibatch 115:
	After ITERATION: 16	Objective: 1082110.48	Old objective: 1081606.30	Improvement: 0.00047
	Minibatch 116:
	After ITERATION: 17	Objective: 1005101.06	Old objective: 1004662.91	Improvement: 0.00044
	Minibatch 117:
	After ITERATION: 17	Objective: 1003218.58	Old objective: 1002777.58	Improvement: 0.00044
	Minibatch 118:
	After ITERATION: 16	Objective: 1115720.47	Old objective: 1115202.88	Improvement: 0.00046
	Minibatch 119:
	After ITERATION: 16	Objective: 1051355.74	Old objective: 1050854.45	Improvement: 0.00048
	Minibatch 120:
	After ITERATION: 16	Objective: 1086837.69	Old objective: 1086318.34	Improvement: 0.00048
	Minibatch 121:
	After ITERATION: 16	Objective: 1061516.71	Old objective: 1061009.79	Improvement: 0.00048
	Minibatch 122:
	After ITERATION: 17	Objective: 988589.83	Old objective: 988164.22	Improvement: 0.00043
	Minibatch 123:
	After ITERATION: 16	Objective: 1025563.35	Old objective: 1025072.77	Improvement: 0.00048
	Minibatch 124:
	After ITERATION: 17	Objective: 1067431.06	Old objective: 1066969.37	Improvement: 0.00043
	Minibatch 125:
	After ITERATION: 16	Objective: 1050025.33	Old objective: 1049537.93	Improvement: 0.00046
	Minibatch 126:
	After ITERATION: 16	Objective: 1019439.84	Old objective: 1018956.99	Improvement: 0.00047
	Minibatch 127:
	After ITERATION: 16	Objective: 1045361.87	Old objective: 1044845.38	Improvement: 0.00049
	Minibatch 128:
	After ITERATION: 16	Objective: 1056190.13	Old objective: 1055689.20	Improvement: 0.00047
	Minibatch 129:
	After ITERATION: 16	Objective: 1102463.37	Old objective: 1101964.14	Improvement: 0.00045
	Minibatch 130:
	After ITERATION: 16	Objective: 1051361.05	Old objective: 1050865.13	Improvement: 0.00047
	Minibatch 131:
	After ITERATION: 16	Objective: 1060032.77	Old objective: 1059545.76	Improvement: 0.00046
	Minibatch 132:
	After ITERATION: 16	Objective: 1043248.71	Old objective: 1042758.84	Improvement: 0.00047
	Minibatch 133:
	After ITERATION: 16	Objective: 1012186.66	Old objective: 1011731.13	Improvement: 0.00045
	Minibatch 134:
	After ITERATION: 16	Objective: 1023425.71	Old objective: 1022977.10	Improvement: 0.00044
	Minibatch 135:
	After ITERATION: 17	Objective: 1012973.09	Old objective: 1012536.65	Improvement: 0.00043
	Minibatch 136:
	After ITERATION: 17	Objective: 933590.39	Old objective: 933189.01	Improvement: 0.00043
	Minibatch 137:
	After ITERATION: 16	Objective: 1060170.38	Old objective: 1059658.45	Improvement: 0.00048
	Minibatch 138:
	After ITERATION: 17	Objective: 1000358.72	Old objective: 999924.72	Improvement: 0.00043
	Minibatch 139:
	After ITERATION: 17	Objective: 1010732.43	Old objective: 1010312.39	Improvement: 0.00042
	Minibatch 140:
	After ITERATION: 16	Objective: 1025983.25	Old objective: 1025508.92	Improvement: 0.00046
	Minibatch 141:
	After ITERATION: 16	Objective: 1090344.03	Old objective: 1089878.39	Improvement: 0.00043
	Minibatch 142:
	After ITERATION: 16	Objective: 1041233.29	Old objective: 1040722.70	Improvement: 0.00049
	Minibatch 143:
	After ITERATION: 16	Objective: 1020799.80	Old objective: 1020322.90	Improvement: 0.00047
	Minibatch 144:
	After ITERATION: 16	Objective: 1093435.88	Old objective: 1092948.65	Improvement: 0.00045
	Minibatch 145:
	After ITERATION: 16	Objective: 1012263.74	Old objective: 1011792.12	Improvement: 0.00047
	Minibatch 146:
	After ITERATION: 16	Objective: 1055632.66	Old objective: 1055154.68	Improvement: 0.00045
	Minibatch 147:
	After ITERATION: 15	Objective: 1147956.81	Old objective: 1147393.11	Improvement: 0.00049
	Minibatch 148:
	After ITERATION: 16	Objective: 1092930.99	Old objective: 1092449.68	Improvement: 0.00044
	Minibatch 149:
	After ITERATION: 16	Objective: 1055100.46	Old objective: 1054621.33	Improvement: 0.00045
	Minibatch 150:
	After ITERATION: 16	Objective: 1065273.87	Old objective: 1064767.40	Improvement: 0.00048
	Minibatch 151:
	After ITERATION: 16	Objective: 1066387.52	Old objective: 1065905.20	Improvement: 0.00045
	Minibatch 152:
	After ITERATION: 16	Objective: 1009737.48	Old objective: 1009272.22	Improvement: 0.00046
	Minibatch 153:
	After ITERATION: 16	Objective: 1096544.91	Old objective: 1096056.93	Improvement: 0.00045
	Minibatch 154:
	After ITERATION: 16	Objective: 967781.76	Old objective: 967302.45	Improvement: 0.00050
	Minibatch 155:
	After ITERATION: 16	Objective: 1123083.12	Old objective: 1122600.07	Improvement: 0.00043
	Minibatch 156:
	After ITERATION: 15	Objective: 1077953.25	Old objective: 1077473.43	Improvement: 0.00045
	Minibatch 289:
	After ITERATION: 15	Objective: 1089217.47	Old objective: 1088718.46	Improvement: 0.00046
	Minibatch 290:
	After ITERATION: 14	Objective: 1165825.84	Old objective: 1165249.45	Improvement: 0.00049
	Minibatch 291:
	After ITERATION: 15	Objective: 1106836.86	Old objective: 1106340.41	Improvement: 0.00045
	Minibatch 292:
	After ITERATION: 15	Objective: 1032560.01	Old objective: 1032077.71	Improvement: 0.00047
	Minibatch 293:
	After ITERATION: 15	Objective: 1027049.26	Old objective: 1026549.71	Improvement: 0.00049
	Minibatch 294:
	After ITERATION: 14	Objective: 1228125.42	Old objective: 1227538.27	Improvement: 0.00048
	Minibatch 295:
	After ITERATION: 15	Objective: 1097882.46	Old objective: 1097374.40	Improvement: 0.00046
	Minibatch 296:
	After ITERATION: 15	Objective: 1056043.15	Old objective: 1055546.46	Improvement: 0.00047
	Minibatch 297:
	After ITERATION: 15	Objective: 1076393.49	Old objective: 1075893.55	Improvement: 0.00046
	Minibatch 298:
	After ITERATION: 15	Objective: 1093885.09	Old objective: 1093368.91	Improvement: 0.00047
	Minibatch 299:
	After ITERATION: 15	Objective: 1017895.30	Old objective: 1017392.30	Improvement: 0.00049
	Minibatch 300:
	After ITERATION: 15	Objective: 1093496.13	Old objective: 1093015.35	Improvement: 0.00044
	Minibatch 301:
	After ITERATION: 15	Objective: 1038203.66	Old objective: 1037734.18	Improvement: 0.00045
	Minibatch 302:
	After ITERATION: 15	Objective: 1091164.84	Old objective: 1090694.54	Improvement: 0.00043
	Minibatch 303:
	After ITERATION: 15	Objective: 1099725.70	Old objective: 1099207.36	Improvement: 0.00047
	Minibatch 304:
	After ITERATION: 15	Objective: 1177362.43	Old objective: 1176845.71	Improvement: 0.00044
	Minibatch 305:
	After ITERATION: 15	Objective: 1139684.47	Old objective: 1139149.42	Improvement: 0.00047
	Minibatch 306:
	After ITERATION: 15	Objective: 1070106.28	Old objective: 1069603.90	Improvement: 0.00047
	Minibatch 307:
	After ITERATION: 15	Objective: 1127686.79	Old objective: 1127213.48	Improvement: 0.00042
	Minibatch 308:
	After ITERATION: 15	Objective: 1107884.48	Old objective: 1107392.22	Improvement: 0.00044
	Minibatch 309:
	After ITERATION: 15	Objective: 1094298.45	Old objective: 1093776.53	Improvement: 0.00048
	Minibatch 310:
	After ITERATION: 15	Objective: 1063033.31	Old objective: 1062522.29	Improvement: 0.00048
	Minibatch 311:
	After ITERATION: 15	Objective: 1105122.11	Old objective: 1104636.89	Improvement: 0.00044
	Minibatch 312:
	After ITERATION: 15	Objective: 1142106.67	Old objective: 1141616.15	Improvement: 0.00043
	Minibatch 313:
	After ITERATION: 15	Objective: 1108809.55	Old objective: 1108329.21	Improvement: 0.00043
	Minibatch 314:
	After ITERATION: 15	Objective: 1069988.46	Old objective: 1069468.85	Improvement: 0.00049
	Minibatch 315:
	After ITERATION: 15	Objective: 1097799.84	Old objective: 1097298.06	Improvement: 0.00046
	Minibatch 316:
	After ITERATION: 15	Objective: 1162605.14	Old objective: 1162114.87	Improvement: 0.00042
	Minibatch 317:
	After ITERATION: 15	Objective: 1109930.94	Old objective: 1109448.15	Improvement: 0.00044
	Minibatch 318:
	After ITERATION: 15	Objective: 1038812.48	Old objective: 1038346.80	Improvement: 0.00045
	Minibatch 319:
	After ITERATION: 15	Objective: 1067610.58	Old objective: 1067128.58	Improvement: 0.00045
	Minibatch 320:
	After ITERATION: 15	Objective: 1069347.33	Old objective: 1068856.14	Improvement: 0.00046
	Minibatch 321:
	After ITERATION: 15	Objective: 1094636.35	Old objective: 1094137.08	Improvement: 0.00046
	Minibatch 322:
	After ITERATION: 15	Objective: 1096606.83	Old objective: 1096113.41	Improvement: 0.00045
	Minibatch 323:
	After ITERATION: 15	Objective: 1038113.38	Old objective: 1037628.07	Improvement: 0.00047
	Minibatch 324:
	After ITERATION: 14	Objective: 1173746.46	Old objective: 1173168.28	Improvement: 0.00049
	Minibatch 325:
	After ITERATION: 15	Objective: 1104455.22	Old objective: 1103955.33	Improvement: 0.00045
	Minibatch 326:
	After ITERATION: 15	Objective: 1031683.65	Old objective: 1031186.31	Improvement: 0.00048
	Minibatch 327:
	After ITERATION: 15	Objective: 1081303.79	Old objective: 1080816.41	Improvement: 0.00045
	Minibatch 328:
	After ITERATION: 15	Objective: 1113955.30	Old objective: 1113462.86	Improvement: 0.00044
	Minibatch 329:
	After ITERATION: 15	Objective: 1113190.81	Old objective: 1112713.83	Improvement: 0.00043
	Minibatch 330:
	After ITERATION: 15	Objective: 1093214.67	Old objective: 1092716.61	Improvement: 0.00046
	Minibatch 331:
	After ITERATION: 15	Objective: 1037706.20	Old objective: 1037235.50	Improvement: 0.00045
	Minibatch 332:
	After ITERATION: 15	Objective: 1023487.36	Old objective: 1022989.60	Improvement: 0.00049
	Minibatch 333:
	After ITERATION: 15	Objective: 1013706.55	Old objective: 1013220.04	Improvement: 0.00048
	Minibatch 334:
	After ITERATION: 15	Objective: 1048025.35	Old objective: 1047560.67	Improvement: 0.00044
	Minibatch 335:
	After ITERATION: 15	Objective: 1026867.44	Old objective: 1026385.18	Improvement: 0.00047
	Minibatch 336:
	After ITERATION: 15	Objective: 1039473.65	Old objective: 1038973.57	Improvement: 0.00048
	Minibatch 337:
	After ITERATION: 15	Objective: 1073516.73	Old objective: 1073039.06	Improvement: 0.00045
	Minibatch 338:
	After ITERATION: 15	Objective: 1131881.75	Old objective: 1131397.01	Improvement: 0.00043
	Minibatch 339:
	After ITERATION: 15	Objective: 1081600.08	Old objective: 1081118.21	Improvement: 0.00045
	Minibatch 340:
	After ITERATION: 15	Objective: 1060541.36	Old objective: 1060060.66	Improvement: 0.00045
	Minibatch 341:
	After ITERATION: 15	Objective: 1085819.39	Old objective: 1085344.10	Improvement: 0.00044
	Minibatch 342:
	After ITERATION: 15	Objective: 1120657.77	Old objective: 1120163.03	Improvement: 0.00044
	Minibatch 343:
	After ITERATION: 15	Objective: 1086096.47	Old objective: 1085602.53	Improvement: 0.00045
	Minibatch 344:
	After ITERATION: 15	Objective: 1016054.40	Old objective: 1015585.74	Improvement: 0.00046
	Minibatch 345:
	After ITERATION: 15	Objective: 1024968.83	Old objective: 1024483.74	Improvement: 0.00047
	Minibatch 346:
	After ITERATION: 15	Objective: 1024766.33	Old objective: 1024270.37	Improvement: 0.00048
	Minibatch 347:
	After ITERATION: 15	Objective: 1071465.65	Old objective: 1070975.24	Improvement: 0.00046
	Minibatch 348:
	After ITERATION: 15	Objective: 1085319.86	Old objective: 1084823.87	Improvement: 0.00046
	Minibatch 349:
	After ITERATION: 15	Objective: 1069978.64	Old objective: 1069490.76	Improvement: 0.00046
	Minibatch 350:
	After ITERATION: 15	Objective: 1125202.56	Old objective: 1124713.97	Improvement: 0.00043
	Minibatch 351:
	After ITERATION: 15	Objective: 1109112.28	Old objective: 1108605.04	Improvement: 0.00046
	Minibatch 352:
	After ITERATION: 15	Objective: 1080803.82	Old objective: 1080348.42	Improvement: 0.00042
	Minibatch 353:
	After ITERATION: 15	Objective: 1112990.55	Old objective: 1112501.74	Improvement: 0.00044
	Minibatch 354:
	After ITERATION: 15	Objective: 1136617.48	Old objective: 1136130.21	Improvement: 0.00043
	Minibatch 355:
	After ITERATION: 15	Objective: 1130881.53	Old objective: 1130410.07	Improvement: 0.00042
	Minibatch 356:
	After ITERATION: 15	Objective: 1079733.41	Old objective: 1079255.95	Improvement: 0.00044
	Minibatch 357:
	After ITERATION: 15	Objective: 1079419.66	Old objective: 1078932.18	Improvement: 0.00045
	Minibatch 358:
	After ITERATION: 15	Objective: 1092988.90	Old objective: 1092499.83	Improvement: 0.00045
	Minibatch 359:
	After ITERATION: 15	Objective: 1088301.20	Old objective: 1087839.03	Improvement: 0.00042
	Minibatch 360:
	After ITERATION: 15	Objective: 1116411.89	Old objective: 1115935.17	Improvement: 0.00043
	Minibatch 361:
	After ITERATION: 15	Objective: 1120961.60	Old objective: 1120464.83	Improvement: 0.00044
	Minibatch 362:
	After ITERATION: 15	Objective: 1101373.62	Old objective: 1100895.62	Improvement: 0.00043
	Minibatch 363:
	After ITERATION: 15	Objective: 1052165.65	Old objective: 1051686.17	Improvement: 0.00046
	Minibatch 364:
	After ITERATION: 15	Objective: 1045082.61	Old objective: 1044591.19	Improvement: 0.00047
	Minibatch 365:
	After ITERATION: 14	Objective: 1106699.89	Old objective: 1106155.34	Improvement: 0.00049
	Minibatch 366:
	After ITERATION: 15	Objective: 1110061.80	Old objective: 1109568.83	Improvement: 0.00044
	Minibatch 367:
	After ITERATION: 15	Objective: 1049802.42	Old objective: 1049321.96	Improvement: 0.00046
	Minibatch 368:
	After ITERATION: 15	Objective: 1081693.68	Old objective: 1081217.30	Improvement: 0.00044
	Minibatch 369:
	After ITERATION: 15	Objective: 1074035.69	Old objective: 1073549.99	Improvement: 0.00045
	Minibatch 370:
	After ITERATION: 14	Objective: 1188333.55	Old objective: 1187756.43	Improvement: 0.00049
	Minibatch 371:
	After ITERATION: 14	Objective: 1190537.93	Old objective: 1189976.66	Improvement: 0.00047
	Minibatch 372:
	After ITERATION: 14	Objective: 265251.94	Old objective: 265135.87	Improvement: 0.00044
Out[28]:
OnlinePoissonMF(batch_size=1000, max_iter=100, n_components=100, n_pass=1,
        random_state=98765, shuffle=True, smoothness=100, tol=0.0005,
        verbose=True)

In [29]:
# the last batch is not full
plot(online_coder_full.bound[:-1])
pass



In [30]:
tagger = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)

In [31]:
tagger.set_components(online_coder_full.gamma_b[:, :K], online_coder_full.rho_b[:, :K])


Out[31]:
PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100,
     tol=0.0005, verbose=True)

In [32]:
Et = tagger.transform(X_test)


	After ITERATION: 15	Objective: 3303796.93	Old objective: 3302344.12	Improvement: 0.00044

In [33]:
Et /= Et.sum(axis=1, keepdims=True)
tags_predicted = Et.dot(online_coder_full.Eb[:, K:])
print tags_predicted.min(), tags_predicted.max()

div_factor = 3
tags_predicted = tags_predicted - div_factor * np.mean(tags_predicted, axis=0)


2.70206653139e-05 1.07069406765

In [34]:
predictat = 20
tags_true_binary = (y_test > 0)

print_out_metrics(tags_true_binary, tags_predicted, predictat)


Precision = 0.131 (0.008)
Recall = 0.154 (0.008)
F-score = 0.141
AROC = 0.718 (0.005)
AP = 0.122 (0.006)

In [ ]: