In [1]:
import h2o
from h2o.estimators.gbm import H2OGradientBoostingEstimator
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
In [2]:
# Connect to an H2O cluster; h2o.init() starts a local one if none is already running.
h2o.init()
In [3]:
# _locate is a private h2o helper that resolves a path inside the h2o git
# checkout; it is also used later to find the test file.
from h2o.h2o import _locate
# Import the airlines training CSV (zipped) into the H2O cluster as a frame.
airlines_train_path = _locate("smalldata/airlines/AirlinesTrain.csv.zip")
air = h2o.import_file(path=airlines_train_path)
In [4]:
# Split the training file ~80/20 into train/validation sets by sampling.
# runif() creates a column of uniform random draws as tall as air.nrow; rows
# whose draw is < 0.8 go to training, the rest to validation.
# A fixed seed keeps the split (and every downstream metric) reproducible
# across Restart & Run All.
SPLIT_SEED = 1234
r = air[0].runif(seed=SPLIT_SEED)
air_train = air[r < 0.8]
air_valid = air[r >= 0.8]
# Predictor columns and the response shared by both models below.
myX = ["Origin", "Dest", "Distance", "UniqueCarrier", "fMonth", "fDayofMonth", "fDayOfWeek"]
myY = "IsDepDelayed"
In [5]:
# Gradient boosting machine: 100 depth-3 trees with learning rate 0.01 and a
# Bernoulli loss for the two-class response. Validation metrics come from air_valid.
gbm = H2OGradientBoostingEstimator(
    distribution="bernoulli",
    ntrees=100,
    max_depth=3,
    learn_rate=0.01,
)
gbm.train(
    x=myX,
    y=myY,
    training_frame=air_train,
    validation_frame=air_valid,
)
gbm.show()
In [6]:
# Binomial GLM fitted with the L-BFGS solver, using the same predictors,
# response and train/validation frames as the GBM.
glm = H2OGeneralizedLinearEstimator(
    family="binomial",
    solver="L_BFGS",
)
glm.train(
    x=myX,
    y=myY,
    training_frame=air_train,
    validation_frame=air_valid,
)
# Print the fitted coefficients.
glm.pprint_coef()
In [7]:
# Import the held-out airlines test CSV (zipped) into the H2O cluster.
airlines_test_path = _locate("smalldata/airlines/AirlinesTest.csv.zip")
air_test = h2o.import_file(path=airlines_test_path)
In [8]:
# Predict with each model on the test file and report its performance.
# Notes:
#  - print(...) with a single argument behaves the same under Python 2 and 3
#    (the original `print "..."` statements are a SyntaxError in Python 3).
#  - Only a cell's last expression is auto-displayed, so the .head() preview is
#    printed explicitly; otherwise it would be computed and silently discarded.
def score_and_report(model, label, frame):
    """Predict on `frame`, print a preview and the model's performance on it.

    Args:
        model: a trained H2O estimator.
        label: display name used in the printed headers (e.g. "GBM").
        frame: H2OFrame to score.

    Returns:
        (predictions, performance) as returned by model.predict(frame) and
        model.model_performance(frame).
    """
    pred = model.predict(frame)
    print("{0} predictions: ".format(label))
    print(pred.head())
    perf = model.model_performance(frame)
    print("{0} performance: ".format(label))
    perf.show()
    return pred, perf


gbm_pred, gbm_perf = score_and_report(gbm, "GBM", air_test)
glm_pred, glm_perf = score_and_report(glm, "GLM", air_test)
In [9]:
# Confusion matrices for both models on the test set.
gbm_CM = gbm_perf.confusion_matrix()
print(gbm_CM)
# print("") emits a blank separator line under both Python 2 and 3; a bare
# `print` is a no-op expression in Python 3 and would print nothing.
print("")
glm_CM = glm_perf.confusion_matrix()
print(glm_CM)
In [10]:
# Summary classification metrics (precision, accuracy, AUC) on the test set.
# NOTE(review): in h2o, precision()/accuracy() appear to return
# [threshold, value] pairs rather than a bare number; confirm for your h2o version.
print('GBM Precision: {0}'.format(gbm_perf.precision()))
print('GBM Accuracy: {0}'.format(gbm_perf.accuracy()))
print('GBM AUC: {0}'.format(gbm_perf.auc()))
# Blank separator line; a bare `print` does nothing in Python 3.
print("")
print('GLM Precision: {0}'.format(glm_perf.precision()))
print('GLM Accuracy: {0}'.format(glm_perf.accuracy()))
print('GLM AUC: {0}'.format(glm_perf.auc()))