In [1]:
import warnings

import numpy as np
import pandas as pd

import scipy
import scipy.sparse

import sklearn
import sklearn.cross_validation
import sklearn.datasets
import sklearn.metrics
import sklearn.svm

warnings.filterwarnings('ignore')

In [2]:
# Read the svmlight-formatted news20 binary dataset: X holds the features,
# y the per-instance labels.
X, y = sklearn.datasets.load_svmlight_file('data/news20.binary')

# One integer id per instance; these ids are what the labeled / unlabeled
# bookkeeping below is expressed in.
instance_ids = np.arange(len(y))

In [3]:
# Stratified 5% / 95% split: the small side seeds the labeled set, the large
# side is treated as the unlabeled pool.
# random_state pins the shuffle so a fresh Restart & Run All reproduces the
# same split (and therefore comparable numbers) every time.
splits = sklearn.cross_validation.StratifiedShuffleSplit(
    y, n_iter=1, test_size=0.95, random_state=0)

# next(iter(...)) is the portable spelling; the iterator's .next() method
# only exists on Python 2.
labeled_indices, unlabeled_indices = next(iter(splits))

In [4]:
# Labeled partition: instance ids, feature rows and labels.
L_ids = instance_ids[labeled_indices]
L = X[labeled_indices]
y_l = y[labeled_indices]

# Unlabeled partition: same three views for the remaining instances.
U_ids = instance_ids[unlabeled_indices]
U = X[unlabeled_indices]
y_u = y[unlabeled_indices]

In [5]:
def increment_svm(svm, L_ids, baseline_accuracy, n_confident=500, n_candidates=2,
                  holdout_fraction=0.9, random_state=None):
    """Pick the best of several candidate augmentations of the labeled set.

    The instances not in ``L_ids`` are ranked by ``svm.decision_function``;
    the ``n_confident`` lowest- and ``n_confident`` highest-scoring ones form
    the high-confidence pool.  ``n_candidates`` stratified random subsets of
    that pool (each keeping ``1 - holdout_fraction`` of it) are appended to the
    labeled set in turn, a fresh LinearSVC is cross-validated on each
    augmented set, and the ids of the best-scoring augmented set are returned.

    Reads the module-level ``X``, ``y`` and ``instance_ids``.

    NOTE: labels for ``L_ids`` are taken from ``y``, while labels for the
    newly added instances are ``svm``'s own predictions, so the returned
    cross-validation accuracy is measured partly against self-assigned labels.

    Parameters
    ----------
    svm : fitted classifier exposing ``decision_function`` and ``predict``.
    L_ids : array of instance ids currently in the labeled set.
    baseline_accuracy : accuracy returned unchanged when there are fewer than
        two unlabeled instances left (nothing can be added).
    n_confident : size of each tail of the decision-function ranking.
    n_candidates : number of random candidate augmentations to evaluate.
    holdout_fraction : fraction of the high-confidence pool left out of each
        candidate augmentation.
    random_state : seed forwarded to the candidate shuffle split.

    Returns
    -------
    (best_L_prime_ids, best_accuracy) : ids of the chosen augmented labeled
        set and its mean 5-fold cross-validation accuracy.
    """
    L_ids = np.asarray(L_ids)
    L = X[L_ids]
    y_l = y[L_ids]

    # Sorted, duplicate-free ids of everything not yet labeled.
    U_ids = np.setdiff1d(instance_ids, L_ids)

    # Shrink the tails when the pool is small so the "smallest" and "largest"
    # slices can never overlap (which would add the same id twice).
    n_tail = min(n_confident, U_ids.size // 2)
    if n_tail == 0:
        return L_ids, baseline_accuracy

    U = X[U_ids]
    ordered_indices = np.argsort(svm.decision_function(U))
    # Explicit start index: ordered_indices[-n_tail:] would be wrong for 0.
    confident_indices = np.concatenate([
        ordered_indices[:n_tail],
        ordered_indices[ordered_indices.size - n_tail:],
    ])

    high_confidence_unlabeled = U[confident_indices]
    high_confidence_ids = U_ids[confident_indices]
    high_confidence_predicted_labels = svm.predict(high_confidence_unlabeled)

    splits = sklearn.cross_validation.StratifiedShuffleSplit(
        high_confidence_predicted_labels, n_iter=n_candidates,
        test_size=holdout_fraction, random_state=random_state)

    saved_L_prime_ids = []
    saved_cv_accuracies = []

    # Only the small ("train") side of each split is used: it is the batch
    # appended to the labeled set.
    for augment_indices, _ in splits:
        augment = high_confidence_unlabeled[augment_indices]
        augment_ids = high_confidence_ids[augment_indices]
        augment_labels = high_confidence_predicted_labels[augment_indices]

        # CSR keeps the stacked matrix row-indexable for cross-validation.
        L_prime = scipy.sparse.vstack([L, augment], format='csr')
        y_l_prime = np.concatenate([y_l, augment_labels])

        svm_prime = sklearn.svm.LinearSVC(penalty='l2', C=10, dual=False)
        accuracy = sklearn.cross_validation.cross_val_score(
            svm_prime, L_prime, y_l_prime, cv=5, n_jobs=7).mean()

        saved_L_prime_ids.append(np.concatenate([L_ids, augment_ids]))
        saved_cv_accuracies.append(accuracy)

    best_index = np.argmax(saved_cv_accuracies)
    return saved_L_prime_ids[best_index], saved_cv_accuracies[best_index]

In [ ]:
# Self-training driver.
#
# Each printed row is: iteration, labeled-set size, mean 5-fold CV accuracy on
# the labeled set, accuracy of the current svm on the remaining unlabeled pool.

# Upper bound on augmentation rounds so that Run All terminates; the recorded
# run below was stopped after round 20.
MAX_ITERATIONS = 20

svm = sklearn.svm.LinearSVC(penalty='l2', C=10, dual=False)
svm.fit(L, y_l)
cv_accuracy = sklearn.cross_validation.cross_val_score(svm, L, y_l, cv=5, n_jobs=7).mean()

# History of the CV accuracy, one entry per round (index 0 is the baseline).
accuracies = [cv_accuracy]

iteration = 0
number_labeled = L.shape[0]
prediction_accuracy = sklearn.metrics.accuracy_score(y_u, svm.predict(U))

# Single parenthesised argument: prints identically on Python 2 and 3.
print("%d\t%d\t%f\t%f" % (iteration, number_labeled, cv_accuracy, prediction_accuracy))

for iteration in range(1, MAX_ITERATIONS + 1):

    L_ids, cv_accuracy = increment_svm(svm, L_ids, cv_accuracy)
    accuracies.append(cv_accuracy)

    # NOTE(review): labels for the grown set are read from y, i.e. the added
    # instances get their ground-truth labels rather than the predicted ones
    # increment_svm scored them with -- confirm this is intended.
    L = X[L_ids]
    y_l = y[L_ids]

    # Everything not labeled yet, as a sorted id array.
    U_ids = np.setdiff1d(instance_ids, L_ids)
    U = X[U_ids]
    y_u = y[U_ids]

    svm = sklearn.svm.LinearSVC(penalty='l2', C=10, dual=False)
    svm.fit(L, y_l)

    number_labeled = L.shape[0]

    if U_ids.size == 0:
        # Pool exhausted: nothing left to score or to draw from.
        break

    prediction_accuracy = sklearn.metrics.accuracy_score(y_u, svm.predict(U))
    print("%d\t%d\t%f\t%f" % (iteration, number_labeled, cv_accuracy, prediction_accuracy))


0	999	0.880869	0.880981
L shape (999, 1355191)
augment shape (100,)
lprime shape (1099, 1355191)
L shape (999, 1355191)
augment shape (100,)
lprime shape (1099, 1355191)
1	1099	0.899004	0.880246
L shape (1099, 1355191)
augment shape (100,)
lprime shape (1199, 1355191)
L shape (1099, 1355191)
augment shape (100,)
lprime shape (1199, 1355191)
2	1199	0.905743	0.879662
L shape (1199, 1355191)
augment shape (100,)
lprime shape (1299, 1355191)
L shape (1199, 1355191)
augment shape (100,)
lprime shape (1299, 1355191)
3	1299	0.916840	0.878965
L shape (1299, 1355191)
augment shape (100,)
lprime shape (1399, 1355191)
L shape (1299, 1355191)
augment shape (100,)
lprime shape (1399, 1355191)
4	1399	0.916382	0.878583
L shape (1399, 1355191)
augment shape (100,)
lprime shape (1499, 1355191)
L shape (1399, 1355191)
augment shape (100,)
lprime shape (1499, 1355191)
5	1499	0.921273	0.878034
L shape (1499, 1355191)
augment shape (100,)
lprime shape (1599, 1355191)
L shape (1499, 1355191)
augment shape (100,)
lprime shape (1599, 1355191)
6	1599	0.931836	0.877208
L shape (1599, 1355191)
augment shape (100,)
lprime shape (1699, 1355191)
L shape (1599, 1355191)
augment shape (100,)
lprime shape (1699, 1355191)
7	1699	0.935252	0.876646
L shape (1699, 1355191)
augment shape (100,)
lprime shape (1799, 1355191)
L shape (1699, 1355191)
augment shape (100,)
lprime shape (1799, 1355191)
8	1799	0.941082	0.875859
L shape (1799, 1355191)
augment shape (100,)
lprime shape (1899, 1355191)
L shape (1799, 1355191)
augment shape (100,)
lprime shape (1899, 1355191)
9	1899	0.945753	0.875062
L shape (1899, 1355191)
augment shape (100,)
lprime shape (1999, 1355191)
L shape (1899, 1355191)
augment shape (100,)
lprime shape (1999, 1355191)
10	1999	0.944970	0.874257
L shape (1999, 1355191)
augment shape (100,)
lprime shape (2099, 1355191)
L shape (1999, 1355191)
augment shape (100,)
lprime shape (2099, 1355191)
11	2099	0.950453	0.873387
L shape (2099, 1355191)
augment shape (100,)
lprime shape (2199, 1355191)
L shape (2099, 1355191)
augment shape (100,)
lprime shape (2199, 1355191)
12	2199	0.950889	0.873743
L shape (2199, 1355191)
augment shape (100,)
lprime shape (2299, 1355191)
L shape (2199, 1355191)
augment shape (100,)
lprime shape (2299, 1355191)
13	2299	0.953893	0.873312
L shape (2299, 1355191)
augment shape (100,)
lprime shape (2399, 1355191)
L shape (2299, 1355191)
augment shape (100,)
lprime shape (2399, 1355191)
14	2399	0.952899	0.872706
L shape (2399, 1355191)
augment shape (100,)
lprime shape (2499, 1355191)
L shape (2399, 1355191)
augment shape (100,)
lprime shape (2499, 1355191)
15	2499	0.959980	0.871692
L shape (2499, 1355191)
augment shape (100,)
lprime shape (2599, 1355191)
L shape (2499, 1355191)
augment shape (100,)
lprime shape (2599, 1355191)
16	2599	0.959599	0.871874
L shape (2599, 1355191)
augment shape (100,)
lprime shape (2699, 1355191)
L shape (2599, 1355191)
augment shape (100,)
lprime shape (2699, 1355191)
17	2699	0.963695	0.870960
L shape (2699, 1355191)
augment shape (100,)
lprime shape (2799, 1355191)
L shape (2699, 1355191)
augment shape (100,)
lprime shape (2799, 1355191)
18	2799	0.964628	0.870094
L shape (2799, 1355191)
augment shape (100,)
lprime shape (2899, 1355191)
L shape (2799, 1355191)
augment shape (100,)
lprime shape (2899, 1355191)
19	2899	0.966189	0.870153
L shape (2899, 1355191)
augment shape (100,)
lprime shape (2999, 1355191)
L shape (2899, 1355191)
augment shape (100,)
lprime shape (2999, 1355191)
20	2999	0.966986	0.869153
L shape (2999, 1355191)
augment shape (100,)
lprime shape (3099, 1355191)
L shape (2999, 1355191)
augment shape (100,)
lprime shape (3099, 1355191)

In [ ]: