Predicting Pathogen from RNAseq data


In [1]:
from math import pow
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
from sklearn.model_selection import train_test_split, ShuffleSplit, cross_val_score, GridSearchCV
from sklearn.svm import LinearSVC
from sklearn.feature_selection import RFE, SelectFromModel
from sklearn.pipeline import Pipeline

random_state = 0xDEADCAFE
use_gsea_genes = True
rfe_step = 0.02

In [2]:
if use_gsea_genes:
    with open("geneSet.txt") as gsfile:
        gsea_genes = list(gsfile.read().splitlines())[1:]
    rfe_step = 1

In [3]:
patient_groups=["control", "viral", "bacterial", "fungal"]
group_id = lambda name: patient_groups.index(name)

X = pd.DataFrame.from_csv("combineSV_WTcpmtable_v2.txt", sep="\s+").T
y = [group_id("bacterial")] * 29 \
    + [group_id("viral")] * 42 \
    + [group_id("fungal")] * 10 \
    + [group_id("control")] * 61

if use_gsea_genes:
    X = X.filter(gsea_genes, axis=1)

print "Complete data set has %d samples and %d features." % (X.shape[0], X.shape[1])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=random_state)
print "Training set has %d samples. Testing set has %d samples." % (len(X_train), len(X_test))


Complete data set has 142 samples and 582 features.
Training set has 99 samples. Testing set has 43 samples.

In [4]:
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold
from sklearn.feature_selection import RFECV
from sklearn.datasets import make_classification
from sklearn.linear_model import RandomizedLogisticRegression

parameters={'estimator__C': [pow(2, i) for i in xrange(-25, 4, 1)]}
svc = LinearSVC(class_weight='balanced')
rfe = RFECV(estimator=svc, step=rfe_step, cv=2, scoring='accuracy', n_jobs=1)
clf = GridSearchCV(rfe, parameters, scoring='accuracy', n_jobs=8, cv=2, verbose=1)
clf.fit(X_train, y_train)

# summarize results
print("Best: %f using %s" % (clf.best_score_, clf.best_params_))
means = clf.cv_results_['mean_test_score']
stds = clf.cv_results_['std_test_score']
params = clf.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))


Fitting 2 folds for each of 29 candidates, totalling 58 fits
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:  2.2min
[Parallel(n_jobs=8)]: Done  58 out of  58 | elapsed:  3.6min finished
Best: 0.484848 using {'estimator__C': 2.384185791015625e-07}
0.434343 (0.023274) with: {'estimator__C': 2.9802322387695312e-08}
0.474747 (0.044710) with: {'estimator__C': 5.960464477539063e-08}
0.444444 (0.013474) with: {'estimator__C': 1.1920928955078125e-07}
0.484848 (0.014699) with: {'estimator__C': 2.384185791015625e-07}
0.484848 (0.034910) with: {'estimator__C': 4.76837158203125e-07}
0.434343 (0.023274) with: {'estimator__C': 9.5367431640625e-07}
0.434343 (0.017149) with: {'estimator__C': 1.9073486328125e-06}
0.393939 (0.022049) with: {'estimator__C': 3.814697265625e-06}
0.393939 (0.038585) with: {'estimator__C': 7.62939453125e-06}
0.414141 (0.058184) with: {'estimator__C': 1.52587890625e-05}
0.454545 (0.023886) with: {'estimator__C': 3.0517578125e-05}
0.454545 (0.044097) with: {'estimator__C': 6.103515625e-05}
0.434343 (0.043485) with: {'estimator__C': 0.0001220703125}
0.434343 (0.043485) with: {'estimator__C': 0.000244140625}
0.383838 (0.048997) with: {'estimator__C': 0.00048828125}
0.414141 (0.037973) with: {'estimator__C': 0.0009765625}
0.404040 (0.007962) with: {'estimator__C': 0.001953125}
0.363636 (0.132292) with: {'estimator__C': 0.00390625}
0.434343 (0.017149) with: {'estimator__C': 0.0078125}
0.434343 (0.043485) with: {'estimator__C': 0.015625}
0.424242 (0.027561) with: {'estimator__C': 0.03125}
0.414141 (0.017761) with: {'estimator__C': 0.0625}
0.404040 (0.028173) with: {'estimator__C': 0.125}
0.343434 (0.010412) with: {'estimator__C': 0.25}
0.383838 (0.008574) with: {'estimator__C': 0.5}
0.434343 (0.003062) with: {'estimator__C': 1.0}
0.373737 (0.018986) with: {'estimator__C': 2.0}
0.404040 (0.012249) with: {'estimator__C': 4.0}
0.434343 (0.037360) with: {'estimator__C': 8.0}

In [5]:
best_estimator = clf.best_estimator_

print("Optimal number of features : %d" % best_estimator.n_features_)
print("Recursive Feature Elimination (RFE) eliminated %d features" % (X.shape[1] - best_estimator.n_features_))

# Plot number of features VS. cross-validation scores
plt.figure()
plt.xlabel("Number of feature subsets selected")
plt.ylabel("Cross validation score")
plt.plot(range(1, len(best_estimator.grid_scores_) + 1), best_estimator.grid_scores_)
plt.show()


Optimal number of features : 248
Recursive Feature Elimination (RFE) eliminated 334 features

In [6]:
%matplotlib inline
from learning_curves import plot_learning_curve
import matplotlib.pyplot as plt

def create_learning_curve(title, model):
    cv = ShuffleSplit(n_splits=2, test_size=0.3, random_state=random_state)
    plot_learning_curve(model, title, X, y, (0.2, 1.01), cv=cv, n_jobs=1)
    
create_learning_curve("Learning Curves Full Feature Set", svc)
create_learning_curve("Learning Curves Limited Feature Set", best_estimator.estimator_)

plt.show()


None

Make predictions based on the model


In [7]:
from classification_metrics import classification_metrics

# make predictions
predicted = best_estimator.predict(X_test)

classification_metrics(y_test, predicted, patient_groups)


Accuracy was 55.81%

             precision    recall  f1-score   support

    control       0.50      0.61      0.55        18
      viral       0.83      0.31      0.45        16
  bacterial       0.64      0.88      0.74         8
     fungal       0.25      1.00      0.40         1

avg / total       0.64      0.56      0.55        43

Confusion Matrix: cols = predictions, rows = actual

                       control          viral      bacterial         fungal
        control             11              1              3              3
          viral             10              5              1              0
      bacterial              1              0              7              0
         fungal              0              0              0              1

Review model predictions


In [8]:
probs = np.array(best_estimator.predict(X_test))

d = {"Predicted": [patient_groups[i] for i in best_estimator.predict(X_test)],
     "Actual": [patient_groups[i] for i in y_test]}

patient_df = pd.DataFrame(d, index=X_test.index)
patient_df.sort_values(by="Actual")


Out[8]:
Actual Predicted
MNC.473 bacterial bacterial
MN_223 bacterial bacterial
MN_208 bacterial bacterial
MNC.557 bacterial bacterial
MN_363 bacterial bacterial
MN_304 bacterial bacterial
MNC.214 bacterial bacterial
MN_307 bacterial control
MNC.611 control control
MN_366 control control
MNC.633 control control
MNC.151 control control
MNC.595 control fungal
MNC.571 control control
MNC.452.y control bacterial
MNC.294 control fungal
MNC.172 control control
MNC.612 control viral
MNC.392 control bacterial
MNC.631 control bacterial
MNC.173 control control
MNC.676 control control
MNC.213 control fungal
MNC.533 control control
MNC.354 control control
MNC.675 control control
MN_323 fungal fungal
MN_382 viral control
MNC.033 viral control
MN_207 viral control
MNC.091 viral viral
MN_097.y viral control
MNC.215 viral control
MNC.474 viral viral
MNC.012.y viral control
MN_100 viral control
MNC.534 viral control
MNC.674 viral control
MN_282 viral bacterial
MN_171 viral viral
MN_284 viral control
MNC.155 viral viral
MNC.014 viral viral

Review patients the model misclassified


In [9]:
patient_df[patient_df["Predicted"] != patient_df["Actual"]]


Out[9]:
Actual Predicted
MN_097.y viral control
MN_284 viral control
MN_282 viral bacterial
MNC.674 viral control
MNC.534 viral control
MNC.631 control bacterial
MNC.392 control bacterial
MNC.612 control viral
MN_307 bacterial control
MN_100 viral control
MNC.595 control fungal
MNC.012.y viral control
MNC.215 viral control
MN_207 viral control
MNC.452.y control bacterial
MNC.294 control fungal
MN_382 viral control
MNC.213 control fungal
MNC.033 viral control