In [70]:
import sys
import warnings
from collections import Counter
from pprint import pprint
from time import time

import joblib
import numpy as np
import scipy
import scipy.sparse
import sklearn
import sklearn.cross_validation
import sklearn.datasets
import sklearn.linear_model
import sklearn.svm
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.semi_supervised import LabelPropagation
from sklearn.semi_supervised import LabelSpreading
warnings.filterwarnings('ignore')
%pylab inline
In [21]:
# Load the 20 Newsgroups corpus. data_home="." points scikit-learn's
# download/cache folder at the current working directory; every other
# option is left at its default.
dataset = fetch_20newsgroups(
    data_home=".",
)
In [62]:
# Collapse the multi-class newsgroup target into a binary one.
group_names = dataset.target_names

# Number of documents in each newsgroup, in target-index order.
group_counts = [
    (dataset.target == group_index).sum()
    for group_index in range(len(group_names))
]

# Index of the first group at which the running document count exceeds 5600
# (presumably about half of the corpus -- confirm against X.shape).
half_docs_group_index = np.flatnonzero(np.cumsum(group_counts) > 5600)[0]

# Documents in groups strictly after that index get label 1, the rest label 0.
y = (dataset.target > half_docs_group_index).astype(int)
In [63]:
# TF-IDF features for every document. Per scikit-learn's documented meaning of
# these options, min_df=4 drops terms seen in fewer than 4 documents and
# max_df=0.5 drops terms seen in more than half of the documents.
vectorizer = TfidfVectorizer(
    min_df=4,
    max_df=0.5,
)
X = vectorizer.fit_transform(dataset.data)
In [81]:
# Show (number of documents, vocabulary size). The parenthesised form prints
# the same text on Python 2 and also runs on Python 3.
print(X.shape)
In [82]:
# Class balance of the binary target: how many documents carry each label.
# (Counter is imported with the other imports at the top of the notebook.)
Counter(y)
Out[82]:
In [74]:
def test_ssl_model(labeled_indices, unlabeled_indices, model):
    """Fit a semi-supervised model on all documents and score both subsets.

    The model is trained on the full notebook-level feature matrix ``X``
    with the targets of ``unlabeled_indices`` replaced by -1 (the marker
    scikit-learn's semi-supervised estimators document as "unlabeled").

    Parameters
    ----------
    labeled_indices : indices into ``X`` / ``y`` whose true labels are kept.
    unlabeled_indices : indices into ``X`` / ``y`` whose labels are hidden.
    model : estimator exposing ``fit(X, y)`` and ``score(X, y)``.

    Returns
    -------
    tuple
        ``(score_on_labeled, score_on_unlabeled)`` as returned by
        ``model.score`` against the true labels.
    """
    features_labeled, targets_labeled = X[labeled_indices], y[labeled_indices]
    features_unlabeled, targets_unlabeled = X[unlabeled_indices], y[unlabeled_indices]

    # Work on a copy so the notebook-level ``y`` keeps the true labels.
    masked_targets = y.copy()
    masked_targets[unlabeled_indices] = -1

    model.fit(X, masked_targets)

    score_labeled = model.score(features_labeled, targets_labeled)
    score_unlabeled = model.score(features_unlabeled, targets_unlabeled)
    return score_labeled, score_unlabeled
def test_baseline_model(labeled_indices, unlabeled_indices, model):
    """Fit a supervised model on the labeled documents only and score both subsets.

    Reads the notebook-level feature matrix ``X`` and target vector ``y``.

    Parameters
    ----------
    labeled_indices : indices into ``X`` / ``y`` used for training.
    unlabeled_indices : indices into ``X`` / ``y`` held out from training.
    model : estimator exposing ``fit(X, y)`` and ``score(X, y)``.

    Returns
    -------
    tuple
        ``(score_on_labeled, score_on_unlabeled)`` as returned by
        ``model.score`` against the true labels.
    """
    features_labeled, targets_labeled = X[labeled_indices], y[labeled_indices]
    features_unlabeled, targets_unlabeled = X[unlabeled_indices], y[unlabeled_indices]

    # Unlike the semi-supervised variant, the held-out rows are never seen
    # during fitting.
    model.fit(features_labeled, targets_labeled)

    score_labeled = model.score(features_labeled, targets_labeled)
    score_unlabeled = model.score(features_unlabeled, targets_unlabeled)
    return score_labeled, score_unlabeled
In [79]:
# --- Experiment configuration -------------------------------------------
test_size = 0.97   # fraction of documents whose labels are hidden
kernel = "rbf"
gamma = 100
alpha = 0.3
seed = 0           # fixes the labeled/unlabeled split so reruns are comparable

# --- Models under comparison --------------------------------------------
lp_model = LabelPropagation(kernel=kernel, alpha=alpha, gamma=gamma)
svm_model = sklearn.svm.LinearSVC()
ridge_model = sklearn.linear_model.RidgeClassifier()

# --- One stratified labeled/unlabeled split -----------------------------
# The "train" side of the split plays the role of the labeled set and the
# "test" side the unlabeled set. Without random_state the split (and hence
# every score below) changed on each run.
splits = sklearn.cross_validation.StratifiedShuffleSplit(
    y, n_iter=1, test_size=test_size, random_state=seed
)
# next(iter(...)) works on both Python 2 and 3, unlike .__iter__().next().
labeled_indices, unlabeled_indices = next(iter(splits))

# --- Results: each line is (score on labeled, score on unlabeled) -------
print(test_baseline_model(labeled_indices, unlabeled_indices, ridge_model))
print(test_baseline_model(labeled_indices, unlabeled_indices, svm_model))
print(test_ssl_model(labeled_indices, unlabeled_indices, lp_model))