In [1]:
# Pulls numpy and matplotlib's pylab names into the notebook namespace; later
# cells rely on this for `numpy`, `plot`, `ylim`, `grid` and `show`, none of
# which is imported explicitly anywhere else.
%pylab inline
In [6]:
import h5py
import pandas
from sklearn.metrics import mean_squared_error
In [7]:
def load_h5(name):
    """Read the 'features', 'qids' and 'labels' datasets from an HDF5 file.

    All three datasets are read fully into memory and then reordered by the
    same permutation so that `qids` is ascending.

    Parameters
    ----------
    name : str
        Path of the HDF5 file to read.

    Returns
    -------
    tuple
        (features, qids, labels), sorted by ascending qid. The relative order
        of rows that share a qid is not guaranteed, because numpy.argsort's
        default sort is not stable.

    Raises
    ------
    KeyError
        If one of the three expected datasets is missing from the file.
    """
    # Single-argument print() emits the same text as the Python 2 print
    # statement and also parses under Python 3.
    print("reading from %s" % (name,))
    # The context manager closes the file even when a dataset is missing or a
    # read fails part-way. Each dataset is sliced with [:] while the file is
    # still open, so nothing touches the file after it is closed.
    with h5py.File(name, 'r') as h5f:
        labels = h5f['labels'][:]
        qids = h5f['qids'][:]
        features = h5f['features'][:]
    print("done")
    sorter = numpy.argsort(qids)
    return features[sorter], qids[sorter], labels[sorter]
In [12]:
# Load the train and test files; each call returns (features, qids, labels)
# reordered so that the query ids are ascending.
Xtr,Qtr,Ytr = load_h5("../data/MSLR/mslr_train")
Xts,Qts,Yts = load_h5("../data/MSLR/mslr_test")
In [32]:
# Number of rows in the train and test feature arrays.
print len(Xtr), len(Xts)
In [50]:
from rep_ef.estimators import EventFilterRegressor
# Regressor configured for 10000 iterations.
# NOTE(review): the `{random}` placeholder in dataset_name is presumably
# expanded by rep_ef itself -- confirm against the library.
ef = EventFilterRegressor(iterations=10000, connection='test_connection', dataset_name='letor-{random}')
In [ ]:
%%time
# Fit on the training features and labels only; the query ids (Qtr) are not
# passed to the estimator.
ef.fit(Xtr, Ytr)
In [63]:
import cPickle
with open('../data/MSLR10k_ef.mx', 'w') as f:
cPickle.dump(ef.formula_mx, f)
In [64]:
with open('../data/MSLR10k_ef.mx', 'r') as f:
formula_mx = cPickle.load(f)
In [89]:
from _matrixnetapplier import MatrixnetClassifier
from StringIO import StringIO
In [72]:
# Test-set MSE of the predictions made by the fitted regressor itself.
mean_squared_error(Yts, ef.predict(Xts))
Out[72]:
In [91]:
# Apply the formula that was reloaded from disk through MatrixnetClassifier and
# score it on the same test set.
# NOTE(review): presumably this is expected to match the ef.predict MSE -- confirm.
mn = MatrixnetClassifier(StringIO(formula_mx))
mean_squared_error(Yts, mn.apply(Xts))
Out[91]:
In [56]:
from itertools import islice
def plot_mse_curves(clf, step=5):
    """Plot staged test and train MSE curves for `clf` and return both.

    Every `step`-th staged prediction of `clf` is scored against the
    notebook-level test set (Xts, Yts) and training set (Xtr, Ytr).
    The test curve is drawn first, then the train curve.

    Parameters
    ----------
    clf : estimator exposing `staged_predict(X)`
    step : int
        Keep one staged prediction out of every `step`, starting with the first.

    Returns
    -------
    tuple
        (train_mses, test_mses) -- train first, even though the test curve is
        computed and plotted first.
    """
    test_curve = [
        mean_squared_error(Yts, staged)
        for staged in islice(clf.staged_predict(Xts), 0, None, step)
    ]
    train_curve = [
        mean_squared_error(Ytr, staged)
        for staged in islice(clf.staged_predict(Xtr), 0, None, step)
    ]
    plot(test_curve)
    plot(train_curve)
    return train_curve, test_curve
In [93]:
# Plot the train/test MSE curves for `ef`, then restrict the y-axis to
# [0.5, 0.6] and draw a grid. The trailing tuple only serves to make both
# calls in one expression.
mses_ef = plot_mse_curves(ef)
ylim(0.5, 0.6), grid()
Out[93]:
In [60]:
# mses_ef is (train_mses, test_mses): show the best train MSE, the best test
# MSE, and the last ten sampled test MSEs.
min(mses_ef[0]), min(mses_ef[1]), mses_ef[1][-10:]
Out[60]:
In [94]:
# Holds fitted regressors keyed by their regularization value.
ef_collection = {}
In [ ]:
%%time
from rep_ef.estimators import EventFilterRegressor
# Train one 600-iteration regressor per regularization value and store each
# fitted model in ef_collection under that value.
# NOTE(review): the loop rebinds the notebook-level name `ef`, so after this
# cell `ef` refers to the last model fitted here rather than the regressor
# built in the earlier cells -- re-running those cells afterwards will use it.
for reg in (0.01, 0.03, 0.1, 0.3):
    ef = EventFilterRegressor(
        iterations=600,
        connection='test_connection',
        dataset_name='letor-{random}',
        regularization=reg,
    )
    ef.fit(Xtr, Ytr)
    ef_collection[reg] = ef
In [104]:
# For each stored model, in ascending order of regularization value, plot its
# train/test MSE curves and print the value next to its best test MSE.
# NOTE(review): block indentation was lost in this export, so it is unclear
# whether show() runs once after the loop (all curves on one figure) or once
# per model (one figure each) -- confirm against the original notebook.
for reg, clf in sorted(ef_collection.iteritems()):
mses_tr, mses_ts = plot_mse_curves(clf)
print reg, min(mses_ts)
show()