Use the pseudocode you came up with in class to write your own 5-fold cross-validation function that splits the data set into five equal parts, using each part in turn as the test set and the remaining four as the training set.


In [1]:
from sklearn import datasets
from random import shuffle
from sklearn import tree
from sklearn.model_selection import cross_val_score
from sklearn import metrics
iris = datasets.load_iris() # load iris data set
x = iris.data[:,2:] # the attributes
y = iris.target # the target variable
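
As a quick sanity check before splitting (assuming the standard iris data set bundled with scikit-learn: 150 samples with four features), the [:,2:] slice keeps only the last two columns, petal length and petal width:

print(x.shape)  # (150, 2): petal length and petal width
print(y.shape)  # (150,): class labels 0, 1 and 2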

In [2]:
def cross_validation(attributes, target):
    dt = tree.DecisionTreeClassifier()  # re-fit from scratch on each fold below
    
    data_and_target = list(zip(attributes, target))
    shuffle(data_and_target)
    list_len = len(attributes)
    len_set = list_len // 5  # size of each fold (any remainder is dropped)
    
    # create the 5 sets
    sets = []
    range_5 = range(0, 5)    
    for i in range_5:
        sets.append(data_and_target[i*len_set:(i+1)*len_set])
    
    # train and test
    for i in range_5:
        print("Testing with set {}, training with all other sets…".format(i+1))
        training_sets = sets.copy()
        test_set = training_sets.pop(i)
        training_list = [item for sublist in training_sets for item in sublist]  # flatten the four training folds
        
        x_train, y_train = zip(*training_list)
        x_test, y_test = zip(*test_set)

        # fit on the four training folds only, then predict the held-out fold
        y_pred = dt.fit(x_train, y_train).predict(x_test)
        print("Accuracy: {0:.3f}\n".format(metrics.accuracy_score(y_test, y_pred)))


# run the custom 5-fold cross-validation function
cross_validation(x, y)

# sklearn's built-in cross_val_score, for comparison
print("Using sklearn's cross_val_score:")
dt = tree.DecisionTreeClassifier()
results = cross_val_score(dt,x,y,cv=5)
for item in results:
    print("Accuracy: {0:.3f}".format(item))


Testing with set 1, training with all other sets…
Accuracy: 0.967

Testing with set 2, training with all other sets…
Accuracy: 1.000

Testing with set 3, training with all other sets…
Accuracy: 0.900

Testing with set 4, training with all other sets…
Accuracy: 0.967

Testing with set 5, training with all other sets…
Accuracy: 0.933

Using sklearn's cross_val_score:
Accuracy: 0.967
Accuracy: 0.967
Accuracy: 0.900
Accuracy: 0.933
Accuracy: 1.000
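
For comparison, here is a minimal sketch of the same per-fold loop written with sklearn's KFold splitter instead of the hand-rolled splitting above (shuffle=True mirrors the random.shuffle call, so the per-fold accuracies will vary between runs):

from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True)
for fold, (train_idx, test_idx) in enumerate(kf.split(x), start=1):
    clf = tree.DecisionTreeClassifier()   # fresh tree for every fold
    clf.fit(x[train_idx], y[train_idx])   # train on the other four folds
    y_pred = clf.predict(x[test_idx])     # predict the held-out fold
    print("Fold {}: accuracy {:.3f}".format(fold, metrics.accuracy_score(y[test_idx], y_pred)))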

In [ ]: