In [1]:
import pickle
import sys
sys.path.append("../tools/")
from feature_format import featureFormat, targetFeatureSplit
# Pickle files are binary, so open in "rb" mode and close the handle when done.
with open("../final_project/final_project_dataset.pkl", "rb") as f:
    data_dict = pickle.load(f)
### first element is our labels, any added elements are predictor
### features. Keep this the same for the mini-project, but you'll
### have a different feature list when you do the final project.
features_list = ["poi", "salary"]
data = featureFormat(data_dict, features_list)
labels, features = targetFeatureSplit(data)
print("%d %d" % (len(labels), len(features)))
Create a decision tree classifier (just use the default parameters), train it on all the data, and print out the accuracy. THIS IS AN OVERFIT TREE, DO NOT TRUST THIS NUMBER! Nonetheless, what is the accuracy?
In [2]:
from sklearn import tree
from time import time
def submitAcc(features, labels):
    # Mean accuracy (a fraction between 0 and 1) of the global classifier clf.
    return clf.score(features, labels)
clf = tree.DecisionTreeClassifier()
t0 = time()
clf.fit(features, labels)
print("done in %0.3fs" % (time() - t0))
In [3]:
pred = clf.predict(features)
# score() returns a fraction, so scale by 100 before printing it as a percentage.
print("Classifier with accuracy %.2f%%" % (100 * submitAcc(features, labels)))
Now you’ll add in training and testing, so that you get a trustworthy accuracy number. Use train_test_split, available in sklearn.cross_validation (moved to sklearn.model_selection in scikit-learn 0.18 and later); hold out 30% of the data for testing and set the random_state parameter to 42. random_state controls which points go into the training set and which are used for testing; setting it to 42 means we know exactly which events are in which set, and can check the results you get.
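To make the role of random_state concrete, here is a minimal pure-Python sketch of a seeded hold-out split. It is not scikit-learn's implementation: `holdout_split` and its parameters are hypothetical names used only for illustration, and the rounding rule for the test-set size is an assumption. It only shows that the same seed always produces the same partition.

```python
import random

def holdout_split(items, test_size=0.30, seed=42):
    """Hypothetical helper: shuffle indices with a seeded RNG, then cut."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)  # same seed -> same shuffled order
    n_test = int(round(len(items) * test_size))
    test_idx = sorted(indices[:n_test])
    train_idx = sorted(indices[n_test:])
    return [items[i] for i in train_idx], [items[i] for i in test_idx]

points = list(range(10))
train_a, test_a = holdout_split(points, seed=42)
train_b, test_b = holdout_split(points, seed=42)

print("%d train, %d test" % (len(train_a), len(test_a)))
print(test_a == test_b)  # identical seeds give identical test sets
```

Because the partition depends only on the seed and the input order, anyone rerunning the split with seed 42 scores the classifier on the same held-out points.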
In [15]:
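# train_test_split moved to sklearn.model_selection in scikit-learn 0.18.
try:
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.30, random_state=42)
print("%d %d" % (len(X_train), len(y_train)))
print("%d %d" % (len(X_test), len(y_test)))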
In [16]:
clf = tree.DecisionTreeClassifier()
t0 = time()
clf.fit(X_train, y_train)
print("done in %0.3fs" % (time() - t0))
In [17]:
pred = clf.predict(X_test)
# score() returns a fraction, so scale by 100 before printing it as a percentage.
print("Classifier with accuracy %.2f%%" % (100 * submitAcc(X_test, y_test)))
Aaaand the testing data brings us back down to earth after that 99% accuracy.
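The gap between the two numbers comes from scoring a model on the very points it was fit to. Here is a minimal sketch of that effect, using a lookup table as a rough stand-in for a fully grown tree. The `Memorizer` class and the synthetic data are hypothetical and not part of the mini-project. The labels are coin flips, so there is nothing real to learn, yet the training score is still perfect.

```python
import random

class Memorizer(object):
    """Hypothetical stand-in for an unpruned tree: stores every training point."""

    def fit(self, X, y):
        self.table = dict(zip(X, y))
        # Unseen points fall back to the most common training label.
        self.default = max(set(y), key=y.count)
        return self

    def predict(self, X):
        return [self.table.get(x, self.default) for x in X]

    def score(self, X, y):
        pred = self.predict(X)
        return sum(p == t for p, t in zip(pred, y)) / float(len(y))

rng = random.Random(42)
X = list(range(200))                 # unique feature values
y = [rng.randint(0, 1) for _ in X]   # random labels: no pattern to find

X_train, y_train = X[:140], y[:140]
X_test, y_test = X[140:], y[140:]

model = Memorizer().fit(X_train, y_train)
print("train accuracy: %.2f" % model.score(X_train, y_train))
print("test accuracy:  %.2f" % model.score(X_test, y_test))
```

The training score is 1.00 by construction, since every training point is simply looked up. The test score reflects only the fallback guess, which is why the held-out number is the one to trust.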