In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
In [2]:
# Read the Iris measurements straight from the UCI repository.
# Column names are passed explicitly, so pandas does not take a header
# from the file itself.
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
dataset = pd.read_csv(url, names=names)
In [3]:
# Preview the first ten rows to check that the columns were read as named.
dataset.head(10)
Out[3]:
In [4]:
# Separate features from labels and hold out a validation set.
# The four measurement columns are selected *before* converting to NumPy so
# that X is a float array; going through `dataset.values` first yields an
# object-dtype array because the string 'class' column is mixed in.
X = dataset.iloc[:, 0:4].to_numpy(dtype=float)
Y = dataset.iloc[:, -1].to_numpy()
# Kept so that any other cell still referring to `array` keeps working.
array = dataset.values
validation_size = 0.20
seed = 7
# Single shuffled hold-out split; random_state makes it reproducible.
X_train, X_validation, Y_train, Y_validation = train_test_split(
    X, Y, test_size=validation_size, random_state=seed
)
In [5]:
# Evaluation settings: the random seed and the metric name that is passed
# as `scoring` to the grid search.
seed, scoring = 7, 'accuracy'
In [6]:
# Hyperparameter grid for the SVC search: two kernels crossed with the
# integer regularisation strengths C = 1..10.
parameters = {
    'kernel': ('linear', 'rbf'),
    'C': list(range(1, 11)),
}
svc = SVC()
In [7]:
# GridSearchCV Performance depending on number of Jobs
jobs = 8
timeit_results = []
for _ in range(jobs):
gscvSVC = GridSearchCV(estimator=SVC(), param_grid=parameters, cv=10, n_jobs=(_+1), scoring=scoring)
tr = %timeit -o gscvSVC.fit(X_train, Y_train)
timeit_results.append(tr)
# best_times are extracted
best_times = [timer.best for timer in timeit_results]
In [9]:
# Plot the best search time against the number of workers.
# x is derived from the number of collected timings rather than hard-coded,
# so the plot stays consistent if the number of timed configurations changes.
x = np.arange(1, len(best_times) + 1)
labels = ['%i. Core' % i for i in x]
fig, ax = plt.subplots()
fig.suptitle('Hyperparameter search time per number of cores')
ax.set_xlabel('Number of cores')
ax.set_ylabel('Search time (s)')
ax.plot(x, best_times)
# One tick per worker count, labelled vertically to avoid overlap.
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation='vertical')
plt.show()
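Search time increases steadily with the number of cores, most likely due to two main factors: (1) the overhead of starting worker processes and dispatching tasks to them, and (2) the very small dataset, which makes each individual fit too cheap to amortize that overhead.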
In [14]:
# Show the cross-validation results of the last fitted search as a table,
# best-ranked parameter combinations first, instead of dumping the raw dict
# of arrays.
pd.DataFrame(gscvSVC.cv_results_).sort_values('rank_test_score').head(10)
Out[14]:
In [ ]: