In [19]:
import matplotlib.pyplot as plt
import math

from sklearn import metrics

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ParameterGrid, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.base import TransformerMixin,BaseEstimator
from sklearn.decomposition import PCA

from sklearn.datasets import make_classification

In [44]:
# Synthetic binary classification problem. flip_y=0.5 assigns the class of
# half of the samples at random (heavy label noise), so expect modest AUCs.
# A fixed random_state makes the dataset -- and every plot below -- reproducible.
X, y = make_classification(n_samples=10000, n_features=25, n_redundant=10,
                           n_repeated=5, flip_y=0.5, random_state=42)

In [45]:
# Hold out 35% of the samples for evaluation. random_state makes the split
# reproducible; stratify=y keeps the class ratio identical in train and test
# so the ROC/AUC numbers are not skewed by a lopsided split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.35, random_state=42, stratify=y)

In [49]:
# Search space: every combination of PCA dimensionality, penalty type and
# regularization parameter C (4 * 2 * 5 = 40 configurations).
# Keys use the '<step name>__<parameter>' form passed to pipeline.set_params.
param_grid = [
    dict(
        pca__n_components=[1, 2, 5, 10],
        clf__penalty=['l1', 'l2'],
        clf__C=[0.001, 0.01, 0.1, 1.0, 10.0],
    ),
]

In [50]:
# PCA for dimensionality reduction followed by logistic regression.
# solver='liblinear' is required because the parameter grid includes
# penalty='l1': the default solver ('lbfgs' since scikit-learn 0.22) only
# supports 'l2' and raises a ValueError when fitted with an l1 penalty.
pipeline = Pipeline([
    ('pca', PCA()),
    ('clf', LogisticRegression(solver='liblinear')),
])

In [51]:
# Fit the pipeline once per parameter combination and draw its ROC curve on
# its own subplot, num_cols subplots per row.
num_cols = 5
# Materialize the grid once; it both sizes the figure and drives the loop.
grid = list(ParameterGrid(param_grid))
num_rows = math.ceil(len(grid) / num_cols)

# plt.subplots creates a fresh figure, so no plt.clf() is needed (calling it
# first only produced an extra, empty figure). squeeze=False keeps `axes`
# 2-D even for a single row, so the flat iteration below works for any grid.
fig, axes = plt.subplots(num_rows, num_cols, sharey=True, squeeze=False)
fig.set_size_inches(num_cols * 3, num_rows * 3)
flat_axes = axes.ravel()

for ax, g in zip(flat_axes, grid):

    pipeline.set_params(**g)
    pipeline.fit(X_train, y_train)

    # predict_proba returns one column per class; take the second column
    # because the first one holds the scores for the 0 class
    preds = pipeline.predict_proba(X_test)[:, 1]

    # fpr means false-positive-rate
    # tpr means true-positive-rate
    fpr, tpr, _ = metrics.roc_curve(y_test, preds)
    auc_score = metrics.auc(fpr, tpr)

    # don't print the whole name or it won't fit: strip the pipeline step
    # prefix ('pca__', 'clf__') from each parameter name
    title = ', '.join('{}={}'.format(name.split('__', 1)[-1], value)
                      for name, value in g.items())
    ax.set_title(title, fontsize=8)
    ax.plot(fpr, tpr, label='AUC = {:.2f}'.format(auc_score))
    ax.legend(loc='lower right')

    # it's helpful to add a diagonal to indicate where chance
    # scores lie (i.e. just flipping a coin)
    ax.plot([0, 1], [0, 1], 'r--')

    ax.set_xlim([-0.1, 1.1])
    ax.set_ylim([-0.1, 1.1])
    ax.set_ylabel('True Positive Rate')
    ax.set_xlabel('False Positive Rate')

# hide subplots left empty when the grid does not fill the last row
for ax in flat_axes[len(grid):]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()


<matplotlib.figure.Figure at 0x7f6e53ce8668>

In [ ]:


In [ ]: