In [22]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import os
%matplotlib inline

In [80]:
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"

def save_fig(fig_id, tight_layout=True):
    path = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id + ".png")
    os.makedirs(os.path.dirname(path), exist_ok=True)  # make sure the output directory exists
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format='png', dpi=300)

In [2]:
# load the data first
raw_data = pd.read_csv("/Users/weilu/Research/data/training_data/training_set.csv")

In [7]:
raw_test_data = pd.read_csv("/Users/weilu/Research/data/test_data/test_data.csv")

In [90]:
# quick look at the raw test data (the output below is the plain DataFrame)
raw_test_data


Out[90]:
Step Qw Rw VTotal QGO Burial Water Rama Chain Chi DSSP P_AP Helix Frag_Mem Name
0 1 0.642206 -25357.541220 -1041.039794 11.431633 -122.857483 -67.858201 -647.698337 168.209829 34.302323 -0.000000 -6.143964 -53.760710 -356.664884 1MBA
1 2 0.592049 -23982.747463 -892.903884 17.434915 -124.679973 -48.483913 -629.313620 187.855837 56.408218 -0.000542 -3.972628 -48.821527 -299.330651 1MBA
2 3 0.646271 -24229.837390 -935.248471 14.043660 -125.714258 -49.021336 -620.676503 179.710908 39.973846 -0.000959 -3.960233 -47.547005 -322.056590 1MBA
3 4 0.667102 -24093.207322 -922.127259 15.770030 -126.519994 -45.511713 -629.844627 168.412484 42.580307 -0.000747 -4.099364 -43.934473 -298.979160 1MBA
4 5 0.697657 -24526.447174 -896.456937 15.329368 -125.012346 -50.401883 -622.295216 190.424375 43.516979 -0.003721 -4.771043 -46.453310 -296.790139 1MBA
5 6 0.661085 -24545.307471 -912.176207 17.765244 -126.028231 -45.715756 -632.824581 191.995985 52.392239 -0.000478 -5.017589 -42.258788 -322.484251 1MBA
6 7 0.680621 -24800.514112 -932.935837 16.294323 -124.390942 -51.328426 -637.752450 193.668463 41.910365 -0.000206 -4.998277 -51.413289 -314.925398 1MBA
7 8 0.696199 -24400.842171 -935.657349 18.676804 -123.564156 -45.588914 -636.146380 177.531075 36.142383 -0.000866 -4.465104 -39.649466 -318.592725 1MBA
8 9 0.663718 -24327.115316 -870.940819 20.282472 -125.657467 -47.968761 -607.062704 193.970523 39.476324 -0.001125 -4.491294 -36.869053 -302.619734 1MBA
9 10 0.671354 -24980.173530 -926.204416 18.292951 -125.698964 -49.284430 -610.906347 172.990613 37.183339 -0.004072 -5.126619 -42.819927 -320.830960 1MBA
10 11 0.694080 -25138.818691 -925.398977 17.999514 -123.510309 -51.258481 -603.937749 165.036957 39.555329 -0.000823 -6.618758 -48.358616 -314.306040 1MBA
11 12 0.703669 -24713.494041 -922.093862 20.054106 -124.950337 -50.010328 -620.778886 179.992447 31.635631 -0.000003 -5.237918 -44.679705 -308.118868 1MBA
12 13 0.732757 -24671.077428 -945.070457 15.605510 -124.380980 -48.849768 -628.017372 185.681130 32.078439 -0.000070 -5.827495 -44.320264 -327.039586 1MBA
13 14 0.718830 -24991.764383 -936.543660 17.728568 -125.320291 -49.076855 -630.690263 169.317541 39.028191 -0.000229 -5.375784 -40.711567 -311.442970 1MBA
14 15 0.713275 -25098.759475 -889.836198 16.876637 -125.043877 -51.100257 -619.055442 189.608059 45.250010 -0.000609 -6.006783 -37.634340 -302.729598 1MBA
15 16 0.711584 -25214.390925 -890.741803 16.353313 -125.297044 -50.543800 -611.617684 198.000987 41.281896 -0.001095 -6.093572 -41.189225 -311.635579 1MBA
16 17 0.706208 -24968.085366 -919.612861 16.176461 -125.233852 -46.510455 -622.183164 172.289259 41.037242 -0.000501 -5.121430 -49.001706 -301.064715 1MBA
17 18 0.719944 -25602.682786 -916.519095 19.294601 -124.077036 -49.548504 -633.283786 184.093797 44.025839 -0.004478 -7.114692 -45.609776 -304.295060 1MBA
18 19 0.713793 -25859.835768 -937.703989 20.235877 -126.189737 -53.950566 -637.831142 183.438780 39.189815 -0.001062 -8.021401 -39.692070 -314.882483 1MBA
19 20 0.697168 -25700.912949 -878.110831 21.966144 -123.168694 -48.965783 -613.603484 188.795123 44.912393 -0.001673 -7.774788 -42.380744 -297.889325 1MBA
20 21 0.693139 -25725.627901 -908.677742 23.145199 -122.558317 -49.467892 -629.747087 192.755908 42.698117 -0.001088 -7.649629 -47.461241 -310.391711 1MBA
21 22 0.707157 -25330.668540 -898.354629 22.081920 -122.970081 -52.810694 -622.750256 197.330470 41.141167 -0.000909 -7.800390 -42.862190 -309.713667 1MBA
22 23 0.662130 -25282.260357 -925.192955 24.324214 -123.243448 -47.626578 -630.950977 181.815043 39.580540 -0.000002 -7.593266 -44.804548 -316.693932 1MBA
23 24 0.671960 -24804.626547 -827.788370 23.284372 -124.902858 -45.398328 -612.131651 203.654104 50.562349 -0.001875 -7.769709 -39.755764 -275.329011 1MBA
24 25 0.675152 -25001.757816 -903.822872 23.040718 -122.895731 -50.983775 -630.292876 186.921610 57.227237 -0.001967 -7.530705 -46.750232 -312.557151 1MBA
25 26 0.699897 -25887.549016 -881.373128 21.641617 -122.196867 -48.584160 -616.648465 205.096054 46.246699 -0.003519 -9.072163 -54.793565 -303.058757 1MBA
26 27 0.743188 -25724.209831 -949.375837 18.495553 -123.160694 -50.919858 -638.535472 173.609172 47.778706 -0.003136 -9.316160 -45.389111 -321.934838 1MBA
27 28 0.738437 -25648.828587 -926.855122 16.621925 -122.760352 -50.775119 -626.182392 182.859002 42.578815 -0.002272 -7.524734 -51.037865 -310.632129 1MBA
28 29 0.685782 -25433.156568 -932.716607 23.160309 -124.408750 -48.037070 -633.271055 196.459668 31.245115 -0.001844 -7.244246 -50.736401 -319.882333 1MBA
29 30 0.706074 -25416.954711 -929.810213 20.672715 -124.027194 -49.193509 -637.037034 188.073066 40.127256 -0.000532 -6.809812 -51.253890 -310.361278 1MBA
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
15996 1981 0.409053 -17999.430222 -687.945391 64.749397 -94.364869 -36.660676 -412.404768 0.000000 0.000000 -14.187624 -13.723739 -10.512898 -170.840214 T251
15997 1982 0.412973 -17706.985422 -662.951404 65.369576 -95.824086 -39.239889 -397.547297 0.000000 0.000000 -12.931692 -12.945733 -10.668885 -159.163397 T251
15998 1983 0.412490 -17734.987728 -651.697834 66.991917 -96.133700 -37.559237 -390.226102 0.000000 0.000000 -12.956347 -12.129292 -10.468002 -159.217072 T251
15999 1984 0.381225 -17385.595026 -668.725491 70.517699 -94.178319 -35.374407 -402.370681 0.000000 0.000000 -14.924097 -11.556034 -13.982130 -166.857522 T251
16000 1985 0.411569 -17586.846192 -680.370962 65.885466 -94.001182 -36.093283 -412.413635 0.000000 0.000000 -14.301954 -11.052160 -12.687701 -165.706513 T251
16001 1986 0.447329 -17906.513485 -696.843772 64.405015 -95.110788 -37.439279 -414.429041 0.000000 0.000000 -15.172935 -11.970102 -13.270076 -173.856568 T251
16002 1987 0.444756 -18042.046753 -691.885893 61.861258 -94.626819 -39.873685 -412.972379 0.000000 0.000000 -16.207708 -13.480528 -11.871953 -164.714078 T251
16003 1988 0.430409 -18093.958579 -684.855304 64.609306 -96.071619 -38.016655 -405.415886 0.000000 0.000000 -14.221513 -12.443746 -13.325077 -169.970113 T251
16004 1989 0.432102 -17881.002230 -684.615045 63.509618 -93.540973 -36.252660 -406.601098 0.000000 0.000000 -16.728372 -11.737532 -14.501614 -168.762415 T251
16005 1990 0.436026 -18087.815053 -691.490257 61.249654 -94.523777 -39.387193 -410.510786 0.000000 0.000000 -16.066533 -12.922115 -12.347810 -166.981698 T251
16006 1991 0.432148 -18028.838163 -672.562106 62.402149 -95.967017 -42.356838 -394.356266 0.000000 0.000000 -14.886479 -13.040497 -13.837657 -160.519501 T251
16007 1992 0.428299 -17821.646326 -667.567507 62.961995 -94.190221 -38.883845 -394.041379 0.000000 0.000000 -15.134337 -13.145636 -9.722363 -165.411722 T251
16008 1993 0.445055 -17953.021943 -676.239168 63.411669 -95.310941 -36.836461 -405.651693 0.000000 0.000000 -14.167776 -11.149783 -10.754804 -165.779378 T251
16009 1994 0.433835 -17825.536395 -681.205940 63.169846 -94.463526 -40.801754 -400.689789 0.000000 0.000000 -15.272762 -12.098315 -11.391685 -169.657954 T251
16010 1995 0.442354 -17724.481384 -664.789940 63.542508 -95.196316 -37.376438 -393.112601 0.000000 0.000000 -16.050392 -12.422823 -10.741236 -163.432641 T251
16011 1996 0.434745 -17754.084522 -672.442338 64.434908 -94.533419 -34.005547 -403.886281 0.000000 0.000000 -15.092364 -11.954105 -14.205701 -163.199828 T251
16012 1997 0.422195 -17769.979895 -681.833646 65.157560 -94.231425 -38.802982 -406.905781 0.000000 0.000000 -15.728867 -12.699764 -12.908596 -165.713791 T251
16013 1998 0.448584 -18092.867298 -677.401650 63.545254 -93.618322 -39.201728 -399.306366 0.000000 0.000000 -14.613241 -11.704113 -13.180279 -169.322856 T251
16014 1999 0.437622 -18257.921397 -675.269864 63.907841 -93.647527 -39.513912 -400.491144 0.000000 0.000000 -15.612118 -10.583574 -11.608279 -167.721153 T251
16015 2000 0.432451 -18133.688583 -652.728393 64.247283 -94.284688 -37.103015 -389.481177 0.000000 0.000000 -14.963553 -10.807478 -9.468101 -160.867663 T251
16016 2001 0.424738 -18167.485016 -681.739774 63.014811 -95.551322 -34.766557 -399.333689 0.000000 0.000000 -14.813711 -13.448593 -12.063449 -174.777265 T251
16017 2002 0.421845 -17945.970506 -688.003028 63.864316 -94.280235 -38.230495 -408.321769 0.000000 0.000000 -14.807868 -12.586852 -17.248887 -166.391240 T251
16018 2003 0.439704 -17809.081368 -655.869433 64.292552 -94.986574 -32.730553 -391.494076 0.000000 0.000000 -12.806316 -12.302707 -7.676176 -168.165583 T251
16019 2004 0.445781 -18133.030311 -691.035529 61.747763 -94.010031 -35.740129 -411.960860 0.000000 0.000000 -13.843014 -12.599330 -12.481455 -172.148473 T251
16020 2005 0.423916 -17989.679224 -683.706551 63.149044 -93.611940 -38.960648 -404.571845 0.000000 0.000000 -15.072901 -11.885887 -13.252408 -169.499966 T251
16021 2006 0.436672 -18045.736997 -684.514419 63.777966 -94.300664 -34.139236 -411.805194 0.000000 0.000000 -14.460893 -11.992803 -11.031269 -170.562326 T251
16022 2007 0.409773 -18010.384341 -643.100641 65.686987 -93.909658 -37.051062 -378.533287 0.000000 0.000000 -14.578533 -11.086955 -8.116355 -165.511778 T251
16023 2008 0.431648 -18019.114654 -670.635961 64.258386 -94.594259 -31.778508 -409.378777 0.000000 0.000000 -14.257835 -11.768737 -10.452200 -162.664032 T251
16024 2009 0.409032 -17908.953299 -672.217025 65.476533 -94.910261 -36.213197 -411.832237 0.000000 0.000000 -13.389249 -12.530290 -5.182432 -163.635892 T251
16025 2010 0.421582 -18004.395479 -695.670619 64.564592 -93.993430 -36.303500 -426.367551 0.000000 0.000000 -15.631319 -13.601444 -5.526691 -168.811276 T251

16026 rows × 15 columns


In [13]:
FEATURES = ["Rw", "VTotal", "QGO"]
LABEL = ["Good"]  # the binary target, constructed below as Qw > 0.7

In [16]:
def normalize(x):
    # z-score normalization: zero mean, unit standard deviation
    return (x - x.mean()) / x.std()

In [19]:
# note: training features are normalized over the whole set, while the test
# features below are normalized per structure (grouped by Name)
X_train = raw_data[FEATURES].transform(normalize)
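
A quick sanity check (added; not part of the original run): after z-score normalization each feature column should come out with mean ~ 0 and standard deviation ~ 1.

In [ ]:
X_train.describe().loc[["mean", "std"]]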

In [93]:
# per-structure normalization: z-score each feature within its Name group
X_test = raw_test_data[FEATURES+["Name"]].groupby("Name").transform(normalize)

In [98]:
Y_test = raw_test_data["Qw"] > 0.7  # the comparison already yields a boolean Series

In [27]:
raw_data["Qw"].hist()


Out[27]:
<matplotlib.axes._subplots.AxesSubplot at 0x111514f98>

In [51]:
Y = raw_data["Qw"] > 0.7  # boolean label: True for "good" structures

In [52]:
# For illustration only -- Scikit-Learn has train_test_split()
def split_train_test(data, y, test_ratio):
    # note: seed np.random first (e.g. np.random.seed(42)) for a reproducible split
    shuffled_indices = np.random.permutation(len(data))
    test_set_size = int(len(data) * test_ratio)
    test_indices = shuffled_indices[:test_set_size]
    train_indices = shuffled_indices[test_set_size:]
    return data.iloc[train_indices], y.iloc[train_indices], data.iloc[test_indices], y.iloc[test_indices]
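
For reference, a sketch of the Scikit-Learn helper mentioned in the comment above, assuming the same 20% test fraction (random_state pins the shuffle so the split is reproducible):

In [ ]:
from sklearn.model_selection import train_test_split

train_set, test_set, train_y, test_y = train_test_split(
    X_train, Y, test_size=0.2, random_state=42)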

In [53]:
train_set, train_y, test_set, test_y = split_train_test(X_train, Y, 0.2)

In [59]:
from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(train_set.values, train_y.values)


Out[59]:
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
       eta0=0.0, fit_intercept=True, l1_ratio=0.15,
       learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,
       penalty='l2', power_t=0.5, random_state=42, shuffle=True, verbose=0,
       warm_start=False)
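
A quick spot check (added) that the fitted classifier produces predictions for a couple of samples:

In [ ]:
sgd_clf.predict(train_set.values[:2])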

In [60]:
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, train_set.values, train_y.values, cv=3, scoring="accuracy")


Out[60]:
array([ 0.83768657,  0.74067164,  0.76865672])
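
Accuracy in the 74-84% range is hard to judge in isolation. A baseline check (added; not in the original run) with a classifier that always predicts the majority class puts the numbers in context:

In [ ]:
from sklearn.dummy import DummyClassifier

# majority-class baseline; a useful model should clearly beat this
dummy_clf = DummyClassifier(strategy="most_frequent")
cross_val_score(dummy_clf, train_set.values, train_y.values, cv=3, scoring="accuracy")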

In [62]:
from sklearn.model_selection import cross_val_predict

y_train_pred = cross_val_predict(sgd_clf, train_set.values, train_y.values, cv=3)

In [64]:
from sklearn.metrics import confusion_matrix

confusion_matrix(train_y.values, y_train_pred)


Out[64]:
array([[802, 194],
       [156, 456]])
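
As a cross-check (added for clarity): reading the matrix as [[TN, FP], [FN, TP]], precision = 456 / (456 + 194) ≈ 0.70 and recall = 456 / (456 + 156) ≈ 0.75. The same numbers via Scikit-Learn:

In [ ]:
from sklearn.metrics import precision_score, recall_score

# precision = TP / (TP + FP), recall = TP / (TP + FN)
precision_score(train_y.values, y_train_pred), recall_score(train_y.values, y_train_pred)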

In [99]:
# note: cross_val_predict refits fresh copies of the classifier on the test set itself;
# to score the model trained above instead, use sgd_clf.decision_function(X_test.values)
y_scores = cross_val_predict(sgd_clf, X_test.values, Y_test.values, cv=3,
                             method="decision_function")

In [100]:
from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(Y_test.values, y_scores)

In [103]:
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    plt.xlabel("Threshold", fontsize=16)
    plt.legend(loc="upper left", fontsize=16)
    plt.ylim([0, 1])

plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.xlim([-10, 10])
save_fig("precision_recall_vs_threshold_plot")
plt.show()


Saving figure precision_recall_vs_threshold_plot

In [65]:
y_scores = cross_val_predict(sgd_clf, train_set.values, train_y.values, cv=3,
                             method="decision_function")

In [67]:
from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(train_y.values, y_scores)

In [73]:
# plot_precision_recall_vs_threshold() is already defined above

plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.xlim([-30, 30])
plt.show()



In [74]:
y_train_pred_90 = (y_scores > 6)
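
The cutoff of 6 above is hardcoded. A sketch (added; assumes a 90% precision target) of how such a threshold can be read off the curve arrays computed in the previous cells:

In [ ]:
# lowest threshold at which precision first reaches 90%
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
y_train_pred_90 = (y_scores >= threshold_90_precision)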

In [76]:
from sklearn.metrics import precision_score, recall_score
precision_score(train_y.values, y_train_pred_90)


Out[76]:
0.94968553459119498
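
Precision near 95% is only half the story; the recall at the same threshold (not printed in the original run) can be checked the same way:

In [ ]:
recall_score(train_y.values, y_train_pred_90)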

In [77]:
def plot_precision_vs_recall(precisions, recalls):
    plt.plot(recalls, precisions, "b-", linewidth=2)
    plt.xlabel("Recall", fontsize=16)
    plt.ylabel("Precision", fontsize=16)
    plt.axis([0, 1, 0, 1])

plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.show()



In [78]:
from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(train_y.values, y_scores)

In [81]:
def plot_roc_curve(fpr, tpr, label=None):
    plt.plot(fpr, tpr, linewidth=2, label=label)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.axis([0, 1, 0, 1])
    plt.xlabel('False Positive Rate', fontsize=16)
    plt.ylabel('True Positive Rate', fontsize=16)

plt.figure(figsize=(8, 6))
plot_roc_curve(fpr, tpr)
save_fig("roc_curve_plot")
plt.show()


Saving figure roc_curve_plot

In [82]:
from sklearn.metrics import roc_auc_score

roc_auc_score(train_y.values, y_scores)


Out[82]:
0.87445697823975643

In [108]:
from sklearn.ensemble import RandomForestClassifier
# class-weighted variant: up-weight the positive ("good") class
forest_clf = RandomForestClassifier(random_state=42, class_weight={0: .1, 1: .9})
y_probas_forest = cross_val_predict(forest_clf, train_set.values, train_y.values, cv=3,
                                    method="predict_proba")
y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class
fpr_forest, tpr_forest, thresholds_forest = roc_curve(train_y.values,y_scores_forest)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, "b:", linewidth=2, label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.legend(loc="lower right", fontsize=16)
save_fig("roc_curve_comparison_plot")
plt.show()


Saving figure roc_curve_comparison_plot

In [83]:
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, train_set.values, train_y.values, cv=3,
                                    method="predict_proba")

In [84]:
y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class
fpr_forest, tpr_forest, thresholds_forest = roc_curve(train_y.values,y_scores_forest)

In [87]:
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, "b:", linewidth=2, label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.legend(loc="lower right", fontsize=16)
save_fig("roc_curve_comparison_plot")
plt.show()


Saving figure roc_curve_comparison_plot
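
For a numeric comparison with the SGD AUC computed above (added; uses the unweighted forest's scores from the previous cells):

In [ ]:
roc_auc_score(train_y.values, y_scores_forest)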

In [86]:
from sklearn.ensemble import RandomForestClassifier
rnd_clf = RandomForestClassifier(n_estimators=500, n_jobs=-1, random_state=42)
rnd_clf.fit(train_set.values, train_y.values)
for name, score in zip(FEATURES, rnd_clf.feature_importances_):
    print(name, score)


Rw 0.299058961859
VTotal 0.269804844783
QGO 0.431136193357
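
Sorting makes the ranking explicit (added for readability): QGO carries the most weight, followed by Rw, then VTotal.

In [ ]:
sorted(zip(rnd_clf.feature_importances_, FEATURES), reverse=True)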