What Happens When You Predict on an Unseen Categorical Level

i.e., a categorical level appears in your test set that was not present in your training set.


This notebook walks through importing, cleaning, training, and testing on a data set where the test set contains a categorical level that was not present in the training set. Run steps 2-6 first (to load all the variables) before jumping between sections and running individual cells.

Originally written for H2O version 3.8.2.1 and Python 2.7; the run captured below used H2O cluster version 3.9.1.3460 and Python 3.5.1.

Import Packages, Initialize an H2O Cluster & Load Data


In [1]:
import h2o, pandas, pprint, operator, numpy as np, matplotlib.pyplot as plt
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
from h2o.estimators.gbm import H2OGradientBoostingEstimator
from h2o.estimators.random_forest import H2ORandomForestEstimator
from h2o.estimators.deeplearning import H2ODeepLearningEstimator
from h2o.estimators.naive_bayes import H2ONaiveBayesEstimator
from tabulate import tabulate

In [2]:
# Set 'interactive = True' for interactive plots, 'interactive = False' if not:
interactive = True
if not interactive:
    import matplotlib  # only pyplot was imported above; the top-level module is needed for use()
    matplotlib.use('Agg', warn=False)

In [35]:
# Connect to a cluster 
h2o.init()


H2O cluster uptime: 8 minutes 15 seconds 464 milliseconds
H2O cluster version: 3.9.1.3460
H2O cluster name: H2O_started_from_python_laurend_kxh246
H2O cluster total nodes: 1
H2O cluster total free memory: 3.3 GB
H2O cluster total cores: 8
H2O cluster allowed cores: 8
H2O cluster healthy: True
H2O Connection ip: 127.0.0.1
H2O Connection port: 54321
H2O Connection proxy: None
Python Version: 3.5.1

In [36]:
# 1 - Load data - One row per flight.
# Columns include origin, destination, departure & arrival times, carrier information, and whether the flight was delayed.
print("Import and Parse airlines data")
# air_path = 'allyears2k_headers.zip'
air_path = "http://h2o-public-test-data.s3.amazonaws.com/smalldata/airlines/allyears2k_headers.zip"
data = h2o.import_file(path = air_path)
# data.describe() # uncomment to see summary of loaded data file
# data.head()     # uncomment to see top of the loaded data file


Import and Parse airlines data

Parse Progress: [##################################################] 100%

Explore Data with Visualizations


In [37]:
# 2 - Data exploration and munging. 
# Generate scatter plots of various columns and plot fitted GLM model.

# Function to fit a GLM model and plot the fitted (x,y) values
def scatter_plot(data, x, y, max_points = 1000, fit = True):
    if(fit):
        lr = H2OGeneralizedLinearEstimator(family = "gaussian")
        lr.train(x=x, y=y, training_frame=data)
        coeff = lr.coef()
    df = data[[x,y]]
    runif = df[y].runif()
    df_subset = df[runif < float(max_points)/data.nrow]
    df_py = h2o.as_list(df_subset)
    
    if(fit): h2o.remove(lr._id)

    # If x variable is string, generate box-and-whisker plot
    if(df_py[x].dtype == "object"):
        if interactive: df_py.boxplot(column = y, by = x)
    # Otherwise, generate a scatter plot
    else:
        if interactive: df_py.plot(x = x, y = y, kind = "scatter")
    
    if(fit):
        x_min = min(df_py[x])
        x_max = max(df_py[x])
        y_min = coeff["Intercept"] + coeff[x]*x_min
        y_max = coeff["Intercept"] + coeff[x]*x_max
        plt.plot([x_min, x_max], [y_min, y_max], "k-")
    if interactive: plt.show()
        
# generate matplotlib plots inside of ipython notebook        
%matplotlib inline  
scatter_plot(data, "Distance", "AirTime", fit = True)


glm Model Build Progress: [##################################################] 100%

In [6]:
# Group flights by month
grouped = data.group_by("Month")
bpd = grouped.count().sum("Cancelled").frame
bpd.show()
bpd.describe()
bpd.dim

# Convert columns to factors
data["Year"]= data["Year"].asfactor()
data["Month"] = data["Month"].asfactor()
data["DayOfWeek"] = data["DayOfWeek"].asfactor()
data["Cancelled"] = data["Cancelled"].asfactor()


Month sum_Cancelled nrow_Year
1 1067 41979
10 19 1999
Rows:2 Cols:3

Chunk compression summary: 
chunk_type chunk_name count count_percentage size size_percentage
C1N 1-Byte Integers (w/o NAs) 1 33.333336 70 B 30.434782
C2 2-Byte Integers 1 33.333336 72 B 31.304348
C2S 2-Byte Fractions 1 33.333336 88 B 38.260868
Frame distribution summary: 
size number_of_rows number_of_chunks_per_column number_of_chunks
127.0.0.1:54321 230 B 2.0 1.0 3.0
mean 230 B 2.0 1.0 3.0
min 230 B 2.0 1.0 3.0
max 230 B 2.0 1.0 3.0
stddev 0 B 0.0 0.0 0.0
total 230 B 2.0 1.0 3.0

Month sum_Cancelled nrow_Year
type int int int
mins 1.0 19.0 1999.0
mean 5.5 543.0 21989.0
maxs 10.0 1067.0 41979.0
sigma 6.363961030678928 741.0479066835018 28270.12911183817
zeros 0 0 0
missing 0 0 0
0 1.0 1067.0 41979.0
1 10.0 19.0 1999.0

In [7]:
# Calculate and plot travel time
# Scheduled times ('CRSArrTime', 'CRSDepTime') are stored as HHMM integers (e.g. 1330 = 1:30 pm)
hour1 = data["CRSArrTime"] / 100
mins1 = data["CRSArrTime"] % 100
arrTime = hour1*60 + mins1

hour2 = data["CRSDepTime"] / 100
mins2 = data["CRSDepTime"] % 100
depTime = hour2*60 + mins2

# Keep the arrival-departure difference when positive; otherwise (e.g. overnight flights) set it to NA
data["TravelTime"] = (arrTime-depTime > 0).ifelse((arrTime-depTime), h2o.H2OFrame([[None]] * data.nrow))

scatter_plot(data, "Distance", "TravelTime")


Parse Progress: [##################################################] 100%

glm Model Build Progress: [##################################################] 100%

In [8]:
# Impute missing values in 'Distance' (grouped by origin and destination) and re-plot
data.impute(column = "Distance", by = ["Origin", "Dest"])
scatter_plot(data, "Distance", "TravelTime")


glm Model Build Progress: [##################################################] 100%
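The cell above imputes the 'Distance' column. To fill the NAs that the ifelse above introduced into 'TravelTime' itself, an analogous group-wise call should work; a minimal sketch, assuming the same frame and the default mean imputation:

# Sketch: impute 'TravelTime' with the per-route (Origin, Dest) mean
data.impute(column = "TravelTime", by = ["Origin", "Dest"])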

Split Data Set into Train and Test Sets


In [9]:
# 3 - Fit a model on train; using test as validation.
# Create test/train split
s = data["Year"].runif()
train = data[s <= 0.75]
test  = data[s > 0.75]

In [10]:
# Replace all instances of 'SFO' in the destination column ('Dest') with 'BB8'
test["Dest"] = (test["Dest"] == 'SFO').ifelse('BB8', test["Dest"])
# print out the number of rows that were affected
test[test['Dest']=='BB8'].shape


Out[10]:
(337, 32)

We replace all instances of 'SFO' in the 'Dest' column with 'BB8' to create a situation in which the test set contains a categorical level ('BB8') that was not present in the training set. (Note: all models will still run without breaking, because new categorical levels are interpreted as if they were NA values.)
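To verify which levels are actually unseen, compare the observed values in each split; a minimal sketch, using the train and test frames created above:

# Sketch: list 'Dest' values that occur in test but never in train
train_dest = set(h2o.as_list(train["Dest"], use_pandas=True)["Dest"])
test_dest  = set(h2o.as_list(test["Dest"],  use_pandas=True)["Dest"])
print(test_dest - train_dest)  # expect {'BB8'}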

Build Models with Supervised Learning Algorithms

Train models for GLM, GBM, DRF, Deep Learning & Naive Bayes in one go!


In [11]:
# Set response column
myY = "IsDepDelayed"
# Set feature columns
myX = ["Origin", "Dest", "Year", "UniqueCarrier", "DayOfWeek", "Month", "Distance", "FlightNum"]

# Predict delays with GLM
data_glm = H2OGeneralizedLinearEstimator(family="binomial", standardize=True)
data_glm.train(x = myX, y = myY, training_frame = train, validation_frame = test)

# Predict delays with GBM
data_gbm2 = H2OGradientBoostingEstimator(balance_classes = False, ntrees = 50, max_depth = 5,
                                         distribution = "bernoulli", learn_rate = 0.1, min_rows = 2)
data_gbm2.train(x = myX, y = myY, training_frame = train, validation_frame = test)

# Predict delays with Distributed Random Forest (DRF)
data_rf2 = H2ORandomForestEstimator(ntrees = 10, max_depth = 5, balance_classes = False)
data_rf2.train(x = myX, y = myY, training_frame = train, validation_frame = test)

# Predict delays with Deep Learning
data_dl = H2ODeepLearningEstimator(hidden = [10,10], epochs = 5, variable_importances = True,
                                   balance_classes = False, loss = "Automatic")
data_dl.train(x = myX, y = myY, training_frame = train, validation_frame=test)

# Predict delays with Naive Bayes
# If Laplace smoothing is disabled ('laplace = 0'), the algorithm assigns zero probability to levels it never saw in training
data_nb = H2ONaiveBayesEstimator(laplace=1) 
data_nb.train(x = myX, y = myY, training_frame = train, validation_frame=test)


glm Model Build Progress: [##################################################] 100%

gbm Model Build Progress: [##################################################] 100%

drf Model Build Progress: [##################################################] 100%

deeplearning Model Build Progress: [##################################################] 100%

naivebayes Model Build Progress: [##################################################] 100%
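Because unseen levels are treated as NA, every model above can still score the rows whose destination is 'BB8'. A quick sanity check (sketch, using the GLM as an example):

# Sketch: predict only on the rows containing the unseen level
bb8_rows = test[test["Dest"] == "BB8"]
print(data_glm.predict(bb8_rows).head())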

Or build models individually

Build GLM


In [12]:
# Set response column
myY = "IsDepDelayed"
# Set feature columns
myX = ["Origin", "Dest", "Year", "UniqueCarrier", "DayOfWeek", "Month", "Distance", "FlightNum"]

# Predict delays with GLM
data_glm = H2OGeneralizedLinearEstimator(family="binomial", standardize=True)
data_glm.train(x = myX, y = myY, training_frame = train, validation_frame = test)

data_glm.model_performance(test)


glm Model Build Progress: [##################################################] 100%

ModelMetricsBinomialGLM: glm
** Reported on test data. **

MSE: 0.21240524065457322
R^2: 0.14856313954399736
LogLoss: 0.6130051173885511
Null degrees of freedom: 11064
Residual degrees of freedom: 10830
Null deviance: 15316.04643380064
Residual deviance: 13565.803247808637
AIC: 14035.803247808637
AUC: 0.7225305170277084
Gini: 0.44506103405541686

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.33560739553025765: 
NO YES Error Rate
NO 1542.0 3735.0 0.7078 (3735.0/5277.0)
YES 544.0 5244.0 0.094 (544.0/5788.0)
Total 2086.0 8979.0 0.3867 (4279.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.3356074 0.7102323 310.0
max f2 0.1528930 0.8468334 384.0
max f0point5 0.5510498 0.6883201 192.0
max accuracy 0.4951515 0.6703118 223.0
max precision 0.9865700 1.0 0.0
max recall 0.0766047 1.0 399.0
max specificity 0.9865700 1.0 0.0
max absolute_MCC 0.5510498 0.3403036 192.0
max min_per_class_accuracy 0.5255959 0.6678037 206.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.9155318 1.8256007 1.8256007 0.9549550 0.9549550 0.0183138 0.0183138 82.5600652 82.5600652
2 0.0200633 0.8851707 1.6878195 1.7567101 0.8828829 0.9189189 0.0169316 0.0352453 68.7819471 75.6710062
3 0.0300045 0.8651742 1.6857841 1.7332105 0.8818182 0.9066265 0.0167588 0.0520041 68.5784067 73.3210485
4 0.0400362 0.8503729 1.6017062 1.7002602 0.8378378 0.8893905 0.0160677 0.0680719 60.1706233 70.0260210
5 0.0500678 0.8379788 1.5844836 1.6770631 0.8288288 0.8772563 0.0158950 0.0839668 58.4483585 67.7063088
6 0.1002259 0.7822708 1.5913726 1.6341792 0.8324324 0.8548242 0.0798203 0.1637871 59.1372644 63.4179232
7 0.1500226 0.7336886 1.3843445 1.5512522 0.7241379 0.8114458 0.0689357 0.2327229 38.4344542 55.1252175
8 0.2 0.6966957 1.3793379 1.5082930 0.7215190 0.7889742 0.0689357 0.3016586 37.9337871 50.8293020
9 0.3000452 0.6390383 1.3055607 1.4406952 0.6829268 0.7536145 0.1306151 0.4322737 30.5560706 44.0695227
10 0.4 0.5834513 1.2237734 1.3864893 0.6401447 0.7252598 0.1223220 0.5545957 22.3773449 38.6489288
11 0.5001356 0.5292895 1.0783585 1.3247963 0.5640794 0.6929888 0.1079820 0.6625777 7.8358467 32.4796309
12 0.6 0.4600544 0.8979000 1.2537434 0.4696833 0.6558217 0.0896683 0.7522460 -10.2099992 25.3743377
13 0.6999548 0.4015261 0.7639942 1.1838063 0.3996383 0.6192382 0.0763649 0.8286109 -23.6005841 18.3806303
14 0.8 0.3414800 0.6942267 1.1225812 0.3631436 0.5872119 0.0694540 0.8980650 -30.5773275 12.2581202
15 0.8999548 0.2668751 0.5876878 1.0631725 0.3074141 0.5561358 0.0587422 0.9568072 -41.2312185 6.3172477
16 1.0 0.0743241 0.4317330 1.0 0.2258356 0.5230908 0.0431928 1.0 -56.8266962 0.0

Out[12]:


In [13]:
glm_pred_output = data_glm.predict(test)
glm_pred_output.head()


glm prediction Progress: [##################################################] 100%
predict NO YES
YES 0.491019 0.508981
YES 0.511203 0.488797
YES 0.418524 0.581476
YES 0.475537 0.524463
YES 0.444683 0.555317
YES 0.381644 0.618356
YES 0.298128 0.701872
YES 0.348571 0.651429
YES 0.356108 0.643892
YES 0.348571 0.651429
Out[13]:

Build GBM


In [14]:
# Set response column
myY = "IsDepDelayed"
# Set feature columns
myX = ["Origin", "Dest", "Year", "UniqueCarrier", "DayOfWeek", "Month", "Distance", "FlightNum"]

# Predict delays with GBM
data_gbm2 = H2OGradientBoostingEstimator(balance_classes = False, ntrees = 50, max_depth = 5,
                                         distribution = "bernoulli", learn_rate = 0.1, min_rows = 2)
data_gbm2.train(x = myX, y = myY, training_frame = train, validation_frame = test)

data_gbm2.model_performance(test)


gbm Model Build Progress: [##################################################] 100%

ModelMetricsBinomial: gbm
** Reported on test data. **

MSE: 0.20758090434261545
R^2: 0.1679017290749275
LogLoss: 0.602413781737347
AUC: 0.7366535927580263
Gini: 0.47330718551605266

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.3688381072712026: 
NO YES Error Rate
NO 2141.0 3136.0 0.5943 (3136.0/5277.0)
YES 800.0 4988.0 0.1382 (800.0/5788.0)
Total 2941.0 8124.0 0.3557 (3936.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.3688381 0.7170788 290.0
max f2 0.1506781 0.8468251 382.0
max f0point5 0.5468037 0.7011673 193.0
max accuracy 0.5217700 0.6825124 205.0
max precision 0.9581786 1.0 0.0
max recall 0.0924748 1.0 396.0
max specificity 0.9581786 1.0 0.0
max absolute_MCC 0.5217700 0.3671079 205.0
max min_per_class_accuracy 0.5054973 0.6782999 215.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.9016550 1.8083780 1.8083780 0.9459459 0.9459459 0.0181410 0.0181410 80.8378005 80.8378005
2 0.0201536 0.8804417 1.6898185 1.7488324 0.8839286 0.9147982 0.0171044 0.0352453 68.9818528 74.8832438
3 0.0300045 0.8642542 1.6310953 1.7101778 0.8532110 0.8945783 0.0160677 0.0513131 63.1095338 71.0177788
4 0.0400362 0.8522578 1.7222648 1.7132064 0.9009009 0.8961625 0.0172771 0.0685902 72.2264766 71.3206354
5 0.0500678 0.8413600 1.7394874 1.7184721 0.9099099 0.8989170 0.0174499 0.0860401 73.9487414 71.8472053
6 0.1001356 0.7920522 1.6011467 1.6598094 0.8375451 0.8682310 0.0801659 0.1662059 60.1146652 65.9809353
7 0.1502033 0.7486083 1.4769198 1.5988462 0.7725632 0.8363418 0.0739461 0.2401520 47.6919757 59.8846154
8 0.2002711 0.7102250 1.4079048 1.5511108 0.7364621 0.8113718 0.0704907 0.3106427 40.7904815 55.1110819
9 0.3000452 0.6530549 1.3558623 1.4861848 0.7092391 0.7774096 0.1352799 0.4459226 35.5862298 48.6184805
10 0.4009038 0.5791790 1.1768346 1.4083591 0.6155914 0.7366997 0.1186938 0.5646164 17.6834626 40.8359106
11 0.5000452 0.5079291 1.0926569 1.3457664 0.5715588 0.7039581 0.1083276 0.6729440 9.2656891 34.5766421
12 0.6002711 0.4436697 0.8843185 1.2687195 0.4625789 0.6636555 0.0886317 0.7615757 -11.5681491 26.8719487
13 0.7000452 0.3864459 0.7497936 1.1947595 0.3922101 0.6249677 0.0748100 0.8363856 -25.0206418 19.4759482
14 0.8000904 0.3229952 0.6475996 1.1263413 0.3387534 0.5891788 0.0647892 0.9011748 -35.2400443 12.6341314
15 0.8999548 0.2566938 0.5934098 1.0672040 0.3104072 0.5582446 0.0592605 0.9604354 -40.6590168 6.7204009
16 1.0 0.0676948 0.3954675 1.0 0.2068654 0.5230908 0.0395646 1.0 -60.4532537 0.0

Out[14]:


In [15]:
data_gbm2.predict(test)


gbm prediction Progress: [##################################################] 100%
predict NO YES
NO 0.726479 0.273521
NO 0.748199 0.251801
YES 0.578755 0.421245
NO 0.718042 0.281958
YES 0.592616 0.407384
YES 0.302950 0.697050
YES 0.254647 0.745353
YES 0.314306 0.685694
YES 0.308002 0.691998
YES 0.314306 0.685694
Out[15]:

Build Distributed Random Forest


In [16]:
# Set response column
myY = "IsDepDelayed"
# Set feature columns
myX = ["Origin", "Dest", "Year", "UniqueCarrier", "DayOfWeek", "Month", "Distance", "FlightNum"]

# Predict delays with Distributed Random Forest (DRF)
data_rf2 = H2ORandomForestEstimator(ntrees = 10, max_depth = 5, balance_classes = False)
data_rf2.train(x = myX, y = myY, training_frame = train, validation_frame = test)

data_rf2.model_performance(test)


drf Model Build Progress: [##################################################] 100%

ModelMetricsBinomial: drf
** Reported on test data. **

MSE: 0.21587610782861724
R^2: 0.13464999766190133
LogLoss: 0.6207972055643648
AUC: 0.7172513354494129
Gini: 0.43450267089882577

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.40442138612270356: 
NO YES Error Rate
NO 1596.0 3681.0 0.6976 (3681.0/5277.0)
YES 579.0 5209.0 0.1 (579.0/5788.0)
Total 2175.0 8890.0 0.385 (4260.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.4044214 0.7097697 316.0
max f2 0.2664590 0.8473607 390.0
max f0point5 0.5498811 0.6798113 208.0
max accuracy 0.5140285 0.6642567 231.0
max precision 0.9083639 1.0 0.0
max recall 0.2172136 1.0 399.0
max specificity 0.9083639 1.0 0.0
max absolute_MCC 0.5140285 0.3267948 231.0
max min_per_class_accuracy 0.5287009 0.6596406 222.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.8517226 1.8600459 1.8600459 0.9729730 0.9729730 0.0186593 0.0186593 86.0045948 86.0045948
2 0.0203344 0.8352025 1.6937114 1.7757698 0.8859649 0.9288889 0.0174499 0.0361092 69.3711430 77.5769792
3 0.0301853 0.8227487 1.7012500 1.7514505 0.8899083 0.9161677 0.0167588 0.0528680 70.1249976 75.1450451
4 0.0413918 0.8076230 1.7267093 1.7447520 0.9032258 0.9126638 0.0193504 0.0722184 72.6709321 74.4751979
5 0.0501582 0.7957986 1.5766713 1.7153757 0.8247423 0.8972973 0.0138217 0.0860401 57.6671250 71.5375707
6 0.1000452 0.7483208 1.5792419 1.6474933 0.8260870 0.8617886 0.0787837 0.1648238 57.9241910 64.7493272
7 0.1500226 0.7075346 1.4346497 1.5765881 0.7504521 0.8246988 0.0717001 0.2365238 43.4649665 57.6588142
8 0.2004519 0.6735833 1.3806823 1.5273025 0.7222222 0.7989179 0.0696268 0.3061507 38.0682254 52.7302531
9 0.3016719 0.6252651 1.2733380 1.4420897 0.6660714 0.7543439 0.1288874 0.4350380 27.3338002 44.2089747
10 0.4 0.5823499 1.1649507 1.3739634 0.609375 0.7187076 0.1145473 0.5495853 16.4950652 37.3963372
11 0.5001356 0.5308816 1.0369495 1.3064875 0.5424188 0.6834116 0.1038355 0.6534209 3.6949502 30.6487520
12 0.6000904 0.4789212 0.9230156 1.2426140 0.4828210 0.65 0.0922598 0.7456807 -7.6984432 24.2614029
13 0.7003163 0.4444286 0.7757180 1.1757941 0.4057710 0.6150471 0.0777471 0.8234278 -22.4282010 17.5794090
14 0.8000904 0.4055548 0.7463303 1.1222385 0.3903986 0.5870326 0.0744644 0.8978922 -25.3669668 12.2238460
15 0.9001356 0.3562070 0.5491644 1.0585444 0.2872629 0.5537149 0.0549413 0.9528334 -45.0835576 5.8544388
16 1.0 0.2172136 0.4723058 1.0 0.2470588 0.5230908 0.0471666 1.0 -52.7694215 0.0

Out[16]:


In [17]:
data_rf2.predict(test)


drf prediction Progress: [##################################################] 100%
predict NO YES
YES 0.425925 0.574075
YES 0.437942 0.562058
YES 0.362506 0.637494
YES 0.420964 0.579036
YES 0.348877 0.651123
YES 0.341171 0.658829
YES 0.339797 0.660203
YES 0.341171 0.658829
YES 0.341171 0.658829
YES 0.341171 0.658829
Out[17]:

Build Deep Learning


In [18]:
# Set response column
myY = "IsDepDelayed"
# Set feature columns
myX = ["Origin", "Dest", "Year", "UniqueCarrier", "DayOfWeek", "Month", "Distance", "FlightNum"]

# Predict delays with Deep Learning
data_dl = H2ODeepLearningEstimator(hidden = [10,10], epochs = 5, variable_importances = True,
                                   balance_classes = False, loss = "Automatic")
data_dl.train(x = myX, y = myY, training_frame = train, validation_frame=test)

data_dl.model_performance(test)


deeplearning Model Build Progress: [##################################################] 100%

ModelMetricsBinomial: deeplearning
** Reported on test data. **

MSE: 0.2170978543303148
R^2: 0.12975256668292567
LogLoss: 0.6237081078318155
AUC: 0.7209628233723193
Gini: 0.44192564674463863

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.29878787453148525: 
NO YES Error Rate
NO 2031.0 3246.0 0.6151 (3246.0/5277.0)
YES 789.0 4999.0 0.1363 (789.0/5788.0)
Total 2820.0 8245.0 0.3647 (4035.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.2987879 0.7124635 310.0
max f2 0.1231821 0.8465415 392.0
max f0point5 0.4613361 0.6840484 220.0
max accuracy 0.4008738 0.6666968 254.0
max precision 0.9920392 1.0 0.0
max recall 0.0999376 1.0 397.0
max specificity 0.9920392 1.0 0.0
max absolute_MCC 0.4613361 0.3322544 220.0
max min_per_class_accuracy 0.4315932 0.6611711 236.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0101220 0.9283660 1.8093006 1.8093006 0.9464286 0.9464286 0.0183138 0.0183138 80.9300647 80.9300647
2 0.0200633 0.8904271 1.7205425 1.7653214 0.9 0.9234234 0.0171044 0.0354181 72.0542502 76.5321386
3 0.0300045 0.8634937 1.6162672 1.7159360 0.8454545 0.8975904 0.0160677 0.0514858 61.6267199 71.5935962
4 0.0400362 0.8469819 1.6705968 1.7045756 0.8738739 0.8916479 0.0167588 0.0682446 67.0596823 70.4575591
5 0.0500678 0.8255683 1.6017062 1.6839646 0.8378378 0.8808664 0.0160677 0.0843124 60.1706233 68.3964583
6 0.1000452 0.7519715 1.5694722 1.6267701 0.8209765 0.8509485 0.0784381 0.1627505 56.9472164 62.6770086
7 0.1500226 0.6936521 1.4346497 1.5627685 0.7504521 0.8174699 0.0717001 0.2344506 43.4649665 56.2768524
8 0.2 0.6494949 1.3413110 1.5074292 0.7016275 0.7885224 0.0670352 0.3014858 34.1311012 50.7429164
9 0.3000452 0.5722457 1.3245570 1.4464534 0.6928636 0.7566265 0.1325155 0.4340014 32.4556960 44.6453402
10 0.4 0.5049965 1.1719186 1.3778507 0.6130199 0.7207411 0.1171389 0.5511403 17.1918642 37.7850726
11 0.5000452 0.4363803 1.0447940 1.3112153 0.5465221 0.6858847 0.1045266 0.6556669 4.4793951 31.1215293
12 0.6 0.3773007 0.9558569 1.2520157 0.5 0.6549179 0.0955425 0.7512094 -4.4143055 25.2015665
13 0.6999548 0.3224394 0.7760936 1.1840531 0.4059675 0.6193673 0.0775743 0.8287837 -22.3906386 18.4053135
14 0.8 0.2690563 0.6907729 1.1223652 0.3613369 0.5870990 0.0691085 0.8978922 -30.9227140 12.2365238
15 0.9000452 0.2048145 0.5871569 1.0628738 0.3071364 0.5559795 0.0587422 0.9566344 -41.2843069 6.2873764
16 1.0 0.0899145 0.4338519 1.0 0.2269439 0.5230908 0.0433656 1.0 -56.6148113 0.0

Out[18]:


In [19]:
data_dl.predict(test)


deeplearning prediction Progress: [##################################################] 100%
predict NO YES
NO 0.772326 0.227674
NO 0.856479 0.143521
NO 0.702109 0.297891
NO 0.761636 0.238364
NO 0.705124 0.294876
YES 0.510008 0.489992
YES 0.285492 0.714508
YES 0.349984 0.650016
YES 0.349114 0.650886
YES 0.349984 0.650016
Out[19]:

Build Naive Bayes


In [20]:
# Set response column
myY = "IsDepDelayed"
# Set feature columns
myX = ["Origin", "Dest", "Year", "UniqueCarrier", "DayOfWeek", "Month", "Distance", "FlightNum"]

# Predict delays with Naive Bayes
# If Laplace smoothing is disabled ('laplace = 0'), the algorithm assigns zero probability to levels it never saw in training
data_nb = H2ONaiveBayesEstimator(laplace=1) 
data_nb.train(x = myX, y = myY, training_frame = train, validation_frame=test)

data_nb.model_performance(test)


naivebayes Model Build Progress: [##################################################] 100%

ModelMetricsBinomial: naivebayes
** Reported on test data. **

MSE: 0.22952227413516735
R^2: 0.07994866844091097
LogLoss: 0.6604736519444419
AUC: 0.6926305809501246
Gini: 0.3852611619002493

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.2227661773519952: 
NO YES Error Rate
NO 1169.0 4108.0 0.7785 (4108.0/5277.0)
YES 430.0 5358.0 0.0743 (430.0/5788.0)
Total 1599.0 9466.0 0.4101 (4538.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.2227662 0.7025043 334.0
max f2 0.0499286 0.8468321 390.0
max f0point5 0.5720444 0.6629061 187.0
max accuracy 0.5720444 0.6419340 187.0
max precision 0.9972780 0.9807692 0.0
max recall 0.0129773 1.0 399.0
max specificity 0.9972780 0.9998105 0.0
max absolute_MCC 0.5720444 0.2928394 187.0
max min_per_class_accuracy 0.5013849 0.6371044 218.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.9867983 1.8600459 1.8600459 0.9729730 0.9729730 0.0186593 0.0186593 86.0045948 86.0045948
2 0.0200633 0.9728750 1.7222648 1.7911554 0.9009009 0.9369369 0.0172771 0.0359364 72.2264766 79.1155357
3 0.0300045 0.9608781 1.5815088 1.7216941 0.8272727 0.9006024 0.0157222 0.0516586 58.1508764 72.1694137
4 0.0400362 0.9452227 1.4122571 1.6441603 0.7387387 0.8600451 0.0141672 0.0658258 41.2257109 64.4160254
5 0.0500678 0.9312264 1.5500383 1.6253019 0.8108108 0.8501805 0.0155494 0.0813753 55.0038290 62.5301882
6 0.1000452 0.8739324 1.5141604 1.5697813 0.7920434 0.8211382 0.0756738 0.1570491 51.4160369 56.9781325
7 0.1500226 0.8246779 1.4173647 1.5190064 0.7414105 0.7945783 0.0708362 0.2278853 41.7364729 51.9006399
8 0.2 0.7696958 1.2894562 1.4616448 0.6745027 0.7645730 0.0644437 0.2923290 28.9456205 46.1644782
9 0.3000452 0.6837439 1.2882914 1.4038429 0.6738934 0.7343373 0.1288874 0.4212163 28.8291384 40.3842911
10 0.4 0.6035968 1.1338918 1.3363856 0.5931284 0.6990511 0.1133379 0.5345543 13.3891783 33.6385625
11 0.5002259 0.5062851 0.9705094 1.2630782 0.5076646 0.6607046 0.0972702 0.6318245 -2.9490604 26.3078175
12 0.6 0.4147244 0.8848603 1.2001843 0.4628623 0.6278054 0.0882861 0.7201106 -11.5139676 20.0184289
13 0.6999548 0.3256296 0.8711608 1.1531991 0.4556962 0.6032279 0.0870767 0.8071873 -12.8839240 15.3199135
14 0.8 0.2599485 0.7650309 1.1046562 0.4001807 0.5778355 0.0765377 0.8837249 -23.4969057 10.4656185
15 0.9000452 0.1793016 0.6786843 1.0573070 0.3550136 0.5530676 0.0678991 0.9516240 -32.1315665 5.7306970
16 1.0 0.0098785 0.4839782 1.0 0.2531646 0.5230908 0.0483760 1.0 -51.6021800 0.0

Out[20]:


In [21]:
data_nb.predict(test)


naivebayes prediction Progress: [##################################################] 100%
predict NO YES
YES 0.484349 0.515651
YES 0.479615 0.520385
YES 0.402805 0.597195
YES 0.444515 0.555485
YES 0.353534 0.646466
YES 0.322299 0.677701
YES 0.258183 0.741817
YES 0.292246 0.707754
YES 0.294163 0.705837
YES 0.292246 0.707754
Out[21]:

Model Performance and Output

Run each cell below to see each model's performance on the test set.
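For a side-by-side comparison, the same metrics can also be collected programmatically; a hedged convenience sketch:

# Sketch: print the test-set AUC of all five models in one loop
models = {"GLM": data_glm, "GBM": data_gbm2, "DRF": data_rf2,
          "Deep Learning": data_dl, "Naive Bayes": data_nb}
for name, model in models.items():
    print("%-13s AUC: %.4f" % (name, model.model_performance(test).auc()))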


In [22]:
# GLM performance
data_glm.model_performance(test)


ModelMetricsBinomialGLM: glm
** Reported on test data. **

MSE: 0.21240524065457322
R^2: 0.14856313954399736
LogLoss: 0.6130051173885511
Null degrees of freedom: 11064
Residual degrees of freedom: 10830
Null deviance: 15316.04643380064
Residual deviance: 13565.803247808637
AIC: 14035.803247808637
AUC: 0.7225305170277084
Gini: 0.44506103405541686

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.33560739553025765: 
NO YES Error Rate
NO 1542.0 3735.0 0.7078 (3735.0/5277.0)
YES 544.0 5244.0 0.094 (544.0/5788.0)
Total 2086.0 8979.0 0.3867 (4279.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.3356074 0.7102323 310.0
max f2 0.1528930 0.8468334 384.0
max f0point5 0.5510498 0.6883201 192.0
max accuracy 0.4951515 0.6703118 223.0
max precision 0.9865700 1.0 0.0
max recall 0.0766047 1.0 399.0
max specificity 0.9865700 1.0 0.0
max absolute_MCC 0.5510498 0.3403036 192.0
max min_per_class_accuracy 0.5255959 0.6678037 206.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.9155318 1.8256007 1.8256007 0.9549550 0.9549550 0.0183138 0.0183138 82.5600652 82.5600652
2 0.0200633 0.8851707 1.6878195 1.7567101 0.8828829 0.9189189 0.0169316 0.0352453 68.7819471 75.6710062
3 0.0300045 0.8651742 1.6857841 1.7332105 0.8818182 0.9066265 0.0167588 0.0520041 68.5784067 73.3210485
4 0.0400362 0.8503729 1.6017062 1.7002602 0.8378378 0.8893905 0.0160677 0.0680719 60.1706233 70.0260210
5 0.0500678 0.8379788 1.5844836 1.6770631 0.8288288 0.8772563 0.0158950 0.0839668 58.4483585 67.7063088
6 0.1002259 0.7822708 1.5913726 1.6341792 0.8324324 0.8548242 0.0798203 0.1637871 59.1372644 63.4179232
7 0.1500226 0.7336886 1.3843445 1.5512522 0.7241379 0.8114458 0.0689357 0.2327229 38.4344542 55.1252175
8 0.2 0.6966957 1.3793379 1.5082930 0.7215190 0.7889742 0.0689357 0.3016586 37.9337871 50.8293020
9 0.3000452 0.6390383 1.3055607 1.4406952 0.6829268 0.7536145 0.1306151 0.4322737 30.5560706 44.0695227
10 0.4 0.5834513 1.2237734 1.3864893 0.6401447 0.7252598 0.1223220 0.5545957 22.3773449 38.6489288
11 0.5001356 0.5292895 1.0783585 1.3247963 0.5640794 0.6929888 0.1079820 0.6625777 7.8358467 32.4796309
12 0.6 0.4600544 0.8979000 1.2537434 0.4696833 0.6558217 0.0896683 0.7522460 -10.2099992 25.3743377
13 0.6999548 0.4015261 0.7639942 1.1838063 0.3996383 0.6192382 0.0763649 0.8286109 -23.6005841 18.3806303
14 0.8 0.3414800 0.6942267 1.1225812 0.3631436 0.5872119 0.0694540 0.8980650 -30.5773275 12.2581202
15 0.8999548 0.2668751 0.5876878 1.0631725 0.3074141 0.5561358 0.0587422 0.9568072 -41.2312185 6.3172477
16 1.0 0.0743241 0.4317330 1.0 0.2258356 0.5230908 0.0431928 1.0 -56.8266962 0.0

Out[22]:


In [23]:
# Distributed Random Forest Performance
data_rf2.model_performance(test)


ModelMetricsBinomial: drf
** Reported on test data. **

MSE: 0.21587610782861724
R^2: 0.13464999766190133
LogLoss: 0.6207972055643648
AUC: 0.7172513354494129
Gini: 0.43450267089882577

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.40442138612270356: 
NO YES Error Rate
NO 1596.0 3681.0 0.6976 (3681.0/5277.0)
YES 579.0 5209.0 0.1 (579.0/5788.0)
Total 2175.0 8890.0 0.385 (4260.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.4044214 0.7097697 316.0
max f2 0.2664590 0.8473607 390.0
max f0point5 0.5498811 0.6798113 208.0
max accuracy 0.5140285 0.6642567 231.0
max precision 0.9083639 1.0 0.0
max recall 0.2172136 1.0 399.0
max specificity 0.9083639 1.0 0.0
max absolute_MCC 0.5140285 0.3267948 231.0
max min_per_class_accuracy 0.5287009 0.6596406 222.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.8517226 1.8600459 1.8600459 0.9729730 0.9729730 0.0186593 0.0186593 86.0045948 86.0045948
2 0.0203344 0.8352025 1.6937114 1.7757698 0.8859649 0.9288889 0.0174499 0.0361092 69.3711430 77.5769792
3 0.0301853 0.8227487 1.7012500 1.7514505 0.8899083 0.9161677 0.0167588 0.0528680 70.1249976 75.1450451
4 0.0413918 0.8076230 1.7267093 1.7447520 0.9032258 0.9126638 0.0193504 0.0722184 72.6709321 74.4751979
5 0.0501582 0.7957986 1.5766713 1.7153757 0.8247423 0.8972973 0.0138217 0.0860401 57.6671250 71.5375707
6 0.1000452 0.7483208 1.5792419 1.6474933 0.8260870 0.8617886 0.0787837 0.1648238 57.9241910 64.7493272
7 0.1500226 0.7075346 1.4346497 1.5765881 0.7504521 0.8246988 0.0717001 0.2365238 43.4649665 57.6588142
8 0.2004519 0.6735833 1.3806823 1.5273025 0.7222222 0.7989179 0.0696268 0.3061507 38.0682254 52.7302531
9 0.3016719 0.6252651 1.2733380 1.4420897 0.6660714 0.7543439 0.1288874 0.4350380 27.3338002 44.2089747
10 0.4 0.5823499 1.1649507 1.3739634 0.609375 0.7187076 0.1145473 0.5495853 16.4950652 37.3963372
11 0.5001356 0.5308816 1.0369495 1.3064875 0.5424188 0.6834116 0.1038355 0.6534209 3.6949502 30.6487520
12 0.6000904 0.4789212 0.9230156 1.2426140 0.4828210 0.65 0.0922598 0.7456807 -7.6984432 24.2614029
13 0.7003163 0.4444286 0.7757180 1.1757941 0.4057710 0.6150471 0.0777471 0.8234278 -22.4282010 17.5794090
14 0.8000904 0.4055548 0.7463303 1.1222385 0.3903986 0.5870326 0.0744644 0.8978922 -25.3669668 12.2238460
15 0.9001356 0.3562070 0.5491644 1.0585444 0.2872629 0.5537149 0.0549413 0.9528334 -45.0835576 5.8544388
16 1.0 0.2172136 0.4723058 1.0 0.2470588 0.5230908 0.0471666 1.0 -52.7694215 0.0

Out[23]:


In [24]:
# GBM Performance
data_gbm2.model_performance(test)


ModelMetricsBinomial: gbm
** Reported on test data. **

MSE: 0.20758090434261545
R^2: 0.1679017290749275
LogLoss: 0.602413781737347
AUC: 0.7366535927580263
Gini: 0.47330718551605266

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.3688381072712026: 
NO YES Error Rate
NO 2141.0 3136.0 0.5943 (3136.0/5277.0)
YES 800.0 4988.0 0.1382 (800.0/5788.0)
Total 2941.0 8124.0 0.3557 (3936.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.3688381 0.7170788 290.0
max f2 0.1506781 0.8468251 382.0
max f0point5 0.5468037 0.7011673 193.0
max accuracy 0.5217700 0.6825124 205.0
max precision 0.9581786 1.0 0.0
max recall 0.0924748 1.0 396.0
max specificity 0.9581786 1.0 0.0
max absolute_MCC 0.5217700 0.3671079 205.0
max min_per_class_accuracy 0.5054973 0.6782999 215.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.9016550 1.8083780 1.8083780 0.9459459 0.9459459 0.0181410 0.0181410 80.8378005 80.8378005
2 0.0201536 0.8804417 1.6898185 1.7488324 0.8839286 0.9147982 0.0171044 0.0352453 68.9818528 74.8832438
3 0.0300045 0.8642542 1.6310953 1.7101778 0.8532110 0.8945783 0.0160677 0.0513131 63.1095338 71.0177788
4 0.0400362 0.8522578 1.7222648 1.7132064 0.9009009 0.8961625 0.0172771 0.0685902 72.2264766 71.3206354
5 0.0500678 0.8413600 1.7394874 1.7184721 0.9099099 0.8989170 0.0174499 0.0860401 73.9487414 71.8472053
6 0.1001356 0.7920522 1.6011467 1.6598094 0.8375451 0.8682310 0.0801659 0.1662059 60.1146652 65.9809353
7 0.1502033 0.7486083 1.4769198 1.5988462 0.7725632 0.8363418 0.0739461 0.2401520 47.6919757 59.8846154
8 0.2002711 0.7102250 1.4079048 1.5511108 0.7364621 0.8113718 0.0704907 0.3106427 40.7904815 55.1110819
9 0.3000452 0.6530549 1.3558623 1.4861848 0.7092391 0.7774096 0.1352799 0.4459226 35.5862298 48.6184805
10 0.4009038 0.5791790 1.1768346 1.4083591 0.6155914 0.7366997 0.1186938 0.5646164 17.6834626 40.8359106
11 0.5000452 0.5079291 1.0926569 1.3457664 0.5715588 0.7039581 0.1083276 0.6729440 9.2656891 34.5766421
12 0.6002711 0.4436697 0.8843185 1.2687195 0.4625789 0.6636555 0.0886317 0.7615757 -11.5681491 26.8719487
13 0.7000452 0.3864459 0.7497936 1.1947595 0.3922101 0.6249677 0.0748100 0.8363856 -25.0206418 19.4759482
14 0.8000904 0.3229952 0.6475996 1.1263413 0.3387534 0.5891788 0.0647892 0.9011748 -35.2400443 12.6341314
15 0.8999548 0.2566938 0.5934098 1.0672040 0.3104072 0.5582446 0.0592605 0.9604354 -40.6590168 6.7204009
16 1.0 0.0676948 0.3954675 1.0 0.2068654 0.5230908 0.0395646 1.0 -60.4532537 0.0

Out[24]:


In [25]:
# Deep Learning Performance
data_dl.model_performance(test)


ModelMetricsBinomial: deeplearning
** Reported on test data. **

MSE: 0.2170978543303148
R^2: 0.12975256668292567
LogLoss: 0.6237081078318155
AUC: 0.7209628233723193
Gini: 0.44192564674463863

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.29878787453148525: 
NO YES Error Rate
NO 2031.0 3246.0 0.6151 (3246.0/5277.0)
YES 789.0 4999.0 0.1363 (789.0/5788.0)
Total 2820.0 8245.0 0.3647 (4035.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.2987879 0.7124635 310.0
max f2 0.1231821 0.8465415 392.0
max f0point5 0.4613361 0.6840484 220.0
max accuracy 0.4008738 0.6666968 254.0
max precision 0.9920392 1.0 0.0
max recall 0.0999376 1.0 397.0
max specificity 0.9920392 1.0 0.0
max absolute_MCC 0.4613361 0.3322544 220.0
max min_per_class_accuracy 0.4315932 0.6611711 236.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0101220 0.9283660 1.8093006 1.8093006 0.9464286 0.9464286 0.0183138 0.0183138 80.9300647 80.9300647
2 0.0200633 0.8904271 1.7205425 1.7653214 0.9 0.9234234 0.0171044 0.0354181 72.0542502 76.5321386
3 0.0300045 0.8634937 1.6162672 1.7159360 0.8454545 0.8975904 0.0160677 0.0514858 61.6267199 71.5935962
4 0.0400362 0.8469819 1.6705968 1.7045756 0.8738739 0.8916479 0.0167588 0.0682446 67.0596823 70.4575591
5 0.0500678 0.8255683 1.6017062 1.6839646 0.8378378 0.8808664 0.0160677 0.0843124 60.1706233 68.3964583
6 0.1000452 0.7519715 1.5694722 1.6267701 0.8209765 0.8509485 0.0784381 0.1627505 56.9472164 62.6770086
7 0.1500226 0.6936521 1.4346497 1.5627685 0.7504521 0.8174699 0.0717001 0.2344506 43.4649665 56.2768524
8 0.2 0.6494949 1.3413110 1.5074292 0.7016275 0.7885224 0.0670352 0.3014858 34.1311012 50.7429164
9 0.3000452 0.5722457 1.3245570 1.4464534 0.6928636 0.7566265 0.1325155 0.4340014 32.4556960 44.6453402
10 0.4 0.5049965 1.1719186 1.3778507 0.6130199 0.7207411 0.1171389 0.5511403 17.1918642 37.7850726
11 0.5000452 0.4363803 1.0447940 1.3112153 0.5465221 0.6858847 0.1045266 0.6556669 4.4793951 31.1215293
12 0.6 0.3773007 0.9558569 1.2520157 0.5 0.6549179 0.0955425 0.7512094 -4.4143055 25.2015665
13 0.6999548 0.3224394 0.7760936 1.1840531 0.4059675 0.6193673 0.0775743 0.8287837 -22.3906386 18.4053135
14 0.8 0.2690563 0.6907729 1.1223652 0.3613369 0.5870990 0.0691085 0.8978922 -30.9227140 12.2365238
15 0.9000452 0.2048145 0.5871569 1.0628738 0.3071364 0.5559795 0.0587422 0.9566344 -41.2843069 6.2873764
16 1.0 0.0899145 0.4338519 1.0 0.2269439 0.5230908 0.0433656 1.0 -56.6148113 0.0

Out[25]:


In [26]:
# Naive Bayes Performance
data_nb.model_performance(test)


ModelMetricsBinomial: naivebayes
** Reported on test data. **

MSE: 0.22952227413516735
R^2: 0.07994866844091097
LogLoss: 0.6604736519444419
AUC: 0.6926305809501246
Gini: 0.3852611619002493

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.2227661773519952: 
NO YES Error Rate
NO 1169.0 4108.0 0.7785 (4108.0/5277.0)
YES 430.0 5358.0 0.0743 (430.0/5788.0)
Total 1599.0 9466.0 0.4101 (4538.0/11065.0)
Maximum Metrics: Maximum metrics at their respective thresholds

metric threshold value idx
max f1 0.2227662 0.7025043 334.0
max f2 0.0499286 0.8468321 390.0
max f0point5 0.5720444 0.6629061 187.0
max accuracy 0.5720444 0.6419340 187.0
max precision 0.9972780 0.9807692 0.0
max recall 0.0129773 1.0 399.0
max specificity 0.9972780 0.9998105 0.0
max absolute_MCC 0.5720444 0.2928394 187.0
max min_per_class_accuracy 0.5013849 0.6371044 218.0
Gains/Lift Table: Avg response rate: 52.31 %

group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain
1 0.0100316 0.9867983 1.8600459 1.8600459 0.9729730 0.9729730 0.0186593 0.0186593 86.0045948 86.0045948
2 0.0200633 0.9728750 1.7222648 1.7911554 0.9009009 0.9369369 0.0172771 0.0359364 72.2264766 79.1155357
3 0.0300045 0.9608781 1.5815088 1.7216941 0.8272727 0.9006024 0.0157222 0.0516586 58.1508764 72.1694137
4 0.0400362 0.9452227 1.4122571 1.6441603 0.7387387 0.8600451 0.0141672 0.0658258 41.2257109 64.4160254
5 0.0500678 0.9312264 1.5500383 1.6253019 0.8108108 0.8501805 0.0155494 0.0813753 55.0038290 62.5301882
6 0.1000452 0.8739324 1.5141604 1.5697813 0.7920434 0.8211382 0.0756738 0.1570491 51.4160369 56.9781325
7 0.1500226 0.8246779 1.4173647 1.5190064 0.7414105 0.7945783 0.0708362 0.2278853 41.7364729 51.9006399
8 0.2 0.7696958 1.2894562 1.4616448 0.6745027 0.7645730 0.0644437 0.2923290 28.9456205 46.1644782
9 0.3000452 0.6837439 1.2882914 1.4038429 0.6738934 0.7343373 0.1288874 0.4212163 28.8291384 40.3842911
10 0.4 0.6035968 1.1338918 1.3363856 0.5931284 0.6990511 0.1133379 0.5345543 13.3891783 33.6385625
11 0.5002259 0.5062851 0.9705094 1.2630782 0.5076646 0.6607046 0.0972702 0.6318245 -2.9490604 26.3078175
12 0.6 0.4147244 0.8848603 1.2001843 0.4628623 0.6278054 0.0882861 0.7201106 -11.5139676 20.0184289
13 0.6999548 0.3256296 0.8711608 1.1531991 0.4556962 0.6032279 0.0870767 0.8071873 -12.8839240 15.3199135
14 0.8 0.2599485 0.7650309 1.1046562 0.4001807 0.5778355 0.0765377 0.8837249 -23.4969057 10.4656185
15 0.9000452 0.1793016 0.6786843 1.0573070 0.3550136 0.5530676 0.0678991 0.9516240 -32.1315665 5.7306970
16 1.0 0.0098785 0.4839782 1.0 0.2531646 0.5230908 0.0483760 1.0 -51.6021800 0.0

Out[26]:

GLM Coefficient Magnitudes / DRF, GBM & Deep Learning Variable Importance


In [27]:
# Calculate magnitude of normalized GLM coefficients
from six import iteritems
glm_varimp = data_glm.coef_norm()
for k,v in iteritems(glm_varimp):
    glm_varimp[k] = abs(glm_varimp[k])
    
# Sort in descending order by magnitude
glm_sorted = sorted(glm_varimp.items(), key = operator.itemgetter(1), reverse = True)
table = tabulate(glm_sorted, headers = ["Predictor", "Normalized Coefficient"], tablefmt = "orgtbl")
print("Coefficient Magnitudes:\n\n" + table)


Coefficient Magnitudes:

| Predictor                 |   Normalized Coefficient |
|---------------------------+--------------------------|
| Year.2008                 |               2.19858    |
| Origin.MDW                |               1.67302    |
| Origin.HPN                |               1.56693    |
| Year.2003                 |               1.55654    |
| Origin.LIH                |               1.49892    |
| Year.2007                 |               1.46806    |
| Dest.LYH                  |               1.25683    |
| UniqueCarrier.HP          |               1.19937    |
| Dest.HTS                  |               1.15907    |
| Origin.TLH                |               1.13519    |
| Origin.LEX                |               1.08467    |
| Origin.HNL                |               1.04093    |
| Origin.CHO                |               1.02275    |
| Origin.ATL                |               0.999746   |
| Year.2001                 |               0.982825   |
| Origin.PSP                |               0.963531   |
| Year.2002                 |               0.963237   |
| UniqueCarrier.TW          |               0.960336   |
| Origin.ERI                |               0.955332   |
| Origin.STX                |               0.947432   |
| Origin.CAE                |               0.869595   |
| Year.2004                 |               0.860763   |
| Origin.ORD                |               0.839797   |
| Origin.PBI                |               0.800896   |
| Year.2006                 |               0.772002   |
| Dest.CHO                  |               0.76052    |
| Origin.LYH                |               0.76052    |
| Origin.MYR                |               0.757717   |
| Origin.IAH                |               0.747385   |
| Dest.FLL                  |               0.741809   |
| Dest.AVL                  |               0.734992   |
| Year.1994                 |               0.725507   |
| Dest.DAY                  |               0.718767   |
| Origin.OGG                |               0.705689   |
| Dest.PSP                  |               0.692691   |
| Origin.CMH                |               0.687228   |
| Origin.STL                |               0.682812   |
| Dest.ICT                  |               0.672662   |
| Origin.ALB                |               0.640848   |
| Origin.ROA                |               0.633851   |
| Year.1996                 |               0.631063   |
| Origin.TRI                |               0.601505   |
| Origin.SYR                |               0.596296   |
| Dest.PNS                  |               0.588903   |
| Origin.AUS                |               0.582793   |
| Dest.GSO                  |               0.579513   |
| Dest.GEG                  |               0.571285   |
| Origin.BTV                |               0.568473   |
| Origin.LAX                |               0.566536   |
| Origin.MIA                |               0.566439   |
| Origin.OMA                |               0.558284   |
| Origin.ACY                |               0.557115   |
| Year.1997                 |               0.552412   |
| Origin.PIT                |               0.54741    |
| Year.1990                 |               0.539981   |
| Dest.FAY                  |               0.53628    |
| Origin.CRW                |               0.527206   |
| Dest.OGG                  |               0.524551   |
| Dest.SFO                  |               0.522489   |
| Origin.CRP                |               0.517442   |
| Origin.EYW                |               0.512461   |
| Dest.IAH                  |               0.511763   |
| UniqueCarrier.PI          |               0.50913    |
| Dest.SLC                  |               0.509118   |
| Origin.FLL                |               0.503688   |
| UniqueCarrier.WN          |               0.498006   |
| Origin.OKC                |               0.495998   |
| Dest.ISP                  |               0.495776   |
| Origin.MSY                |               0.489918   |
| Year.2005                 |               0.48681    |
| Origin.MCO                |               0.477058   |
| Origin.IND                |               0.473817   |
| Origin.RNO                |               0.469941   |
| Dest.BGM                  |               0.466168   |
| Dest.IND                  |               0.457781   |
| Origin.PHL                |               0.456633   |
| Dest.UCA                  |               0.452046   |
| Origin.GSO                |               0.45192    |
| Origin.PWM                |               0.437979   |
| Dest.TPA                  |               0.431338   |
| Origin.SRQ                |               0.422504   |
| Origin.DAY                |               0.413958   |
| Dest.CAK                  |               0.413293   |
| Origin.DFW                |               0.412907   |
| Origin.SAV                |               0.410809   |
| UniqueCarrier.CO          |               0.40659    |
| Dest.BDL                  |               0.406095   |
| Dest.CAE                  |               0.401772   |
| Dest.COS                  |               0.397441   |
| Origin.OAK                |               0.381577   |
| Origin.BOI                |               0.380043   |
| Origin.JAX                |               0.376318   |
| Dest.PBI                  |               0.373462   |
| Dest.SDF                  |               0.372169   |
| Year.1995                 |               0.367096   |
| Dest.BUF                  |               0.36642    |
| Dest.SWF                  |               0.36335    |
| Origin.ELP                |               0.356208   |
| Dest.CLE                  |               0.355224   |
| Origin.LGA                |               0.353049   |
| Dest.JAX                  |               0.352995   |
| Dest.ALB                  |               0.348283   |
| Origin.BUF                |               0.346038   |
| Origin.MCI                |               0.345448   |
| Origin.BOS                |               0.343105   |
| Origin.BDL                |               0.340712   |
| Dest.BTV                  |               0.332514   |
| Origin.ROC                |               0.330604   |
| Origin.MSP                |               0.327787   |
| Origin.JFK                |               0.321378   |
| Dest.MCO                  |               0.319434   |
| Dest.ABQ                  |               0.31625    |
| Dest.CMH                  |               0.31476    |
| Dest.LIH                  |               0.314168   |
| Dest.LBB                  |               0.308055   |
| UniqueCarrier.US          |               0.306931   |
| Origin.BWI                |               0.302733   |
| Dest.STL                  |               0.299111   |
| Origin.SLC                |               0.298382   |
| Origin.SAN                |               0.296723   |
| Origin.CLT                |               0.295934   |
| Origin.TUS                |               0.295071   |
| Origin.BHM                |               0.293635   |
| Dest.KOA                  |               0.289065   |
| Year.1992                 |               0.2847     |
| Dest.TUL                  |               0.283613   |
| Origin.UCA                |               0.268208   |
| DayOfWeek.5               |               0.263935   |
| Origin.CLE                |               0.261628   |
| Year.1991                 |               0.252562   |
| Origin.COS                |               0.249223   |
| Dest.CLT                  |               0.249012   |
| Origin.BUR                |               0.245153   |
| Dest.FAT                  |               0.242143   |
| Month.10                  |               0.241364   |
| Year.1987                 |               0.241364   |
| Dest.ROA                  |               0.240329   |
| Origin.PHF                |               0.239999   |
| Dest.OMA                  |               0.239768   |
| Origin.SMF                |               0.234955   |
| Dest.SEA                  |               0.229919   |
| Dest.HRL                  |               0.218484   |
| Origin.DEN                |               0.216983   |
| Year.1999                 |               0.211729   |
| Dest.PDX                  |               0.210602   |
| Dest.LAS                  |               0.206601   |
| Dest.MSP                  |               0.204311   |
| Origin.IAD                |               0.197859   |
| Dest.BUR                  |               0.197227   |
| Year.2000                 |               0.193379   |
| Origin.SFO                |               0.190741   |
| Origin.BGM                |               0.189911   |
| Origin.SWF                |               0.189525   |
| Dest.IAD                  |               0.18846    |
| Dest.SAT                  |               0.185439   |
| Dest.ROC                  |               0.183779   |
| Dest.LAX                  |               0.180429   |
| Origin.TPA                |               0.179862   |
| Dest.PHL                  |               0.170191   |
| Distance                  |               0.167585   |
| Origin.DTW                |               0.159638   |
| Dest.BNA                  |               0.159266   |
| DayOfWeek.4               |               0.157277   |
| Dest.OAK                  |               0.15125    |
| Dest.FNT                  |               0.150786   |
| Origin.PDX                |               0.15073    |
| Origin.RDU                |               0.144973   |
| Origin.ORF                |               0.144369   |
| Year.1989                 |               0.144251   |
| Dest.BWI                  |               0.143257   |
| Dest.ELP                  |               0.142565   |
| Dest.MDW                  |               0.134653   |
| UniqueCarrier.AA          |               0.133115   |
| Year.1993                 |               0.130076   |
| Dest.PHX                  |               0.129837   |
| Dest.SAN                  |               0.129307   |
| DayOfWeek.2               |               0.127522   |
| Origin.JAN                |               0.125264   |
| Origin.MKE                |               0.124907   |
| Origin.DAL                |               0.121417   |
| Dest.EWR                  |               0.121082   |
| Month.1                   |               0.117726   |
| Dest.RNO                  |               0.113732   |
| Origin.SJC                |               0.113654   |
| Dest.MAF                  |               0.112554   |
| Dest.ORD                  |               0.111985   |
| DayOfWeek.6               |               0.109718   |
| Dest.SBN                  |               0.109414   |
| Origin.ISP                |               0.105969   |
| UniqueCarrier.UA          |               0.0936254  |
| Origin.RIC                |               0.0921766  |
| Origin.PVD                |               0.0818253  |
| Dest.ORF                  |               0.07756    |
| Dest.MIA                  |               0.0775558  |
| Origin.CHS                |               0.072934   |
| Origin.LAS                |               0.0687986  |
| Dest.ABE                  |               0.0664163  |
| Origin.EWR                |               0.0642849  |
| Dest.DTW                  |               0.0633446  |
| Dest.HNL                  |               0.0610877  |
| Dest.ACY                  |               0.0584843  |
| Intercept                 |               0.0583759  |
| Dest.RSW                  |               0.054807   |
| Dest.DEN                  |               0.0531848  |
| Dest.TUS                  |               0.0503133  |
| Dest.SMF                  |               0.0495967  |
| Dest.BOS                  |               0.0492589  |
| Dest.MKE                  |               0.0488848  |
| Dest.LGA                  |               0.045868   |
| Origin.MEM                |               0.0415562  |
| Origin.DCA                |               0.0409492  |
| Origin.SNA                |               0.0380913  |
| Dest.MDT                  |               0.0359403  |
| DayOfWeek.7               |               0.0330296  |
| Dest.DFW                  |               0.032886   |
| Dest.MCI                  |               0.0325893  |
| Dest.SNA                  |               0.0310822  |
| Year.1998                 |               0.0305543  |
| DayOfWeek.3               |               0.0289725  |
| Year.1988                 |               0.0255815  |
| FlightNum                 |               0.0238619  |
| Dest.PIT                  |               0.0205454  |
| Origin.ONT                |               0.0203642  |
| Origin.SJU                |               0.0172619  |
| Origin.MDT                |               0.0171563  |
| Dest.BOI                  |               0.0171016  |
| Dest.MSY                  |               0.010624   |
| Dest.DAL                  |               0.00981252 |
| Dest.ATL                  |               0.00782332 |
| UniqueCarrier.DL          |               0.00670187 |
| Dest.PVD                  |               0.00252622 |
| Dest.JAN                  |               0.00252588 |
| Dest.SJC                  |               0.00177487 |
| Dest.SYR                  |               0.00165583 |
| Dest.AUS                  |               0.00048016 |
| Origin.KOA                |               0          |
| Dest.ORH                  |               0          |
| Month.missing(NA)         |               0          |
| Dest.MYR                  |               0          |
| Origin.LBB                |               0          |
| Dest.BHM                  |               0          |
| Origin.ABE                |               0          |
| DayOfWeek.1               |               0          |
| Origin.LIT                |               0          |
| Dest.AVP                  |               0          |
| Dest.SRQ                  |               0          |
| Dest.ILM                  |               0          |
| Origin.GRR                |               0          |
| Dest.PHF                  |               0          |
| Origin.MLB                |               0          |
| Origin.DSM                |               0          |
| Origin.SAT                |               0          |
| Dest.RDU                  |               0          |
| Dest.HPN                  |               0          |
| Dest.JFK                  |               0          |
| Dest.CRP                  |               0          |
| Dest.EUG                  |               0          |
| Dest.OKC                  |               0          |
| Origin.MHT                |               0          |
| Dest.STT                  |               0          |
| Origin.AMA                |               0          |
| Origin.TYS                |               0          |
| Dest.DSM                  |               0          |
| Origin.EGE                |               0          |
| Dest.GRR                  |               0          |
| Dest.LIT                  |               0          |
| Origin.LAN                |               0          |
| Dest.LEX                  |               0          |
| Dest.ANC                  |               0          |
| Dest.MRY                  |               0          |
| Dest.CHS                  |               0          |
| UniqueCarrier.missing(NA) |               0          |
| Dest.AMA                  |               0          |
| Origin.MFR                |               0          |
| Origin.SEA                |               0          |
| Dest.OAJ                  |               0          |
| Origin.GNV                |               0          |
| Origin.SCK                |               0          |
| Origin.HOU                |               0          |
| Dest.ERI                  |               0          |
| Dest.GSP                  |               0          |
| Dest.DCA                  |               0          |
| Origin.CVG                |               0          |
| Dest.ELM                  |               0          |
| Dest.MHT                  |               0          |
| Origin.TUL                |               0          |
| Origin.ICT                |               0          |
| Origin.MRY                |               0          |
| Year.missing(NA)          |               0          |
| Origin.GEG                |               0          |
| Dest.RIC                  |               0          |
| Dest.HOU                  |               0          |
| Dest.ONT                  |               0          |
| UniqueCarrier.PS          |               0          |
| Dest.CVG                  |               0          |
| Origin.missing(NA)        |               0          |
| Dest.missing(NA)          |               0          |
| Origin.BIL                |               0          |
| Dest.EYW                  |               0          |
| Origin.PHX                |               0          |
| Dest.CHA                  |               0          |
| Dest.TOL                  |               0          |
| Dest.PWM                  |               0          |
| Origin.MAF                |               0          |
| Origin.SDF                |               0          |
| DayOfWeek.missing(NA)     |               0          |
| Origin.SBN                |               0          |
| Origin.HRL                |               0          |
| Origin.BNA                |               0          |
| Origin.RSW                |               0          |
| Origin.STT                |               0          |
| Origin.ANC                |               0          |
| Dest.SJU                  |               0          |
| Origin.ABQ                |               0          |
| Origin.AVP                |               0          |
| Dest.SCK                  |               0          |

In [28]:
# Plot GLM Coefficient Magnitudes
all_coefficient_magnitudes = pandas.DataFrame(glm_sorted)
coefficient_magnitudes = all_coefficient_magnitudes[1:10]   # rows 1-9 of the sorted coefficient table
feature_labels = list(coefficient_magnitudes[0])            # column 0: feature names

plt.figure(figsize=(16, 5))
h = plt.bar(range(len(feature_labels)), coefficient_magnitudes[1], width=0.6, color='aqua')
plt.title("GLM Coefficient Magnitudes", fontsize=20)
xticks_pos = [0.65 * patch.get_width() + patch.get_xy()[0] for patch in h]
plt.xticks(xticks_pos, feature_labels, fontsize=13, ha='right')


Out[28]:
[Bar chart: "GLM Coefficient Magnitudes", coefficient magnitude per feature]

In [29]:
# DRF Variable Importance
data_rf2.varimp(use_pandas=True)


Out[29]:
variable relative_importance scaled_importance percentage
0 Year 2689.601562 1.000000 0.365336
1 Origin 2137.199951 0.794616 0.290301
2 Dest 1349.895386 0.501894 0.183360
3 UniqueCarrier 508.218170 0.188957 0.069033
4 Distance 320.266632 0.119076 0.043503
5 FlightNum 259.619904 0.096527 0.035265
6 DayOfWeek 93.597191 0.034800 0.012714
7 Month 3.603836 0.001340 0.000490

In [30]:
# Plot DRF Feature Importances
importances = data_rf2.varimp(use_pandas=True)
feature_labels = list(importances['variable'])

plt.figure(figsize=(14, 5))
h = plt.bar(range(len(feature_labels)), importances['relative_importance'], width=0.6, color='aqua')
plt.title("DRF Feature Importances", fontsize=20)
xticks_pos = [0.65 * patch.get_width() + patch.get_xy()[0] for patch in h]
plt.xticks(xticks_pos, feature_labels, fontsize=12, ha='right')


Out[30]:
[Bar chart: "DRF Feature Importances", relative importance per variable]

In [31]:
# GBM Variable Importance
data_gbm2.varimp(use_pandas=True)


Out[31]:
variable relative_importance scaled_importance percentage
0 Origin 4716.587402 1.000000 0.369159
1 Year 3886.308838 0.823966 0.304174
2 Dest 2975.810547 0.630924 0.232911
3 UniqueCarrier 469.562317 0.099556 0.036752
4 DayOfWeek 339.474030 0.071975 0.026570
5 FlightNum 280.978149 0.059572 0.021992
6 Distance 107.865120 0.022869 0.008442
7 Month 0.000000 0.000000 0.000000

In [32]:
# Plot GBM Feature Importances
importances = data_gbm2.varimp(use_pandas=True)
feature_labels = list(importances['variable'])

plt.figure(figsize=(14, 5))
h = plt.bar(range(len(feature_labels)), importances['relative_importance'], width=0.6, color='aqua')
plt.title("GBM Feature Importances", fontsize=20)
xticks_pos = [0.65 * patch.get_width() + patch.get_xy()[0] for patch in h]
plt.xticks(xticks_pos, feature_labels, fontsize=12, ha='right')


Out[32]:
[Bar chart: "GBM Feature Importances", relative importance per variable]

In [33]:
# Deep Learning Variable Importance
data_dl.varimp(use_pandas=True)


Out[33]:
variable relative_importance scaled_importance percentage
0 Origin.MDW 1.000000 1.000000 0.008451
1 Year.2003 0.866093 0.866093 0.007320
2 Dest.EWR 0.774157 0.774157 0.006543
3 Origin.ORD 0.742141 0.742141 0.006272
4 Year.2001 0.736653 0.736653 0.006226
5 Year.2002 0.710980 0.710980 0.006009
6 Origin.SFO 0.678087 0.678087 0.005731
7 Month.1 0.666030 0.666030 0.005629
8 Origin.AUS 0.665489 0.665489 0.005624
9 Dest.SFO 0.657573 0.657573 0.005557
10 Year.1989 0.640961 0.640961 0.005417
11 Year.1990 0.637166 0.637166 0.005385
12 Year.2007 0.604424 0.604424 0.005108
13 Year.1994 0.591667 0.591667 0.005000
14 Origin.PSP 0.591646 0.591646 0.005000
15 Dest.FLL 0.591537 0.591537 0.004999
16 Year.2008 0.587823 0.587823 0.004968
17 Dest.CAE 0.585480 0.585480 0.004948
18 Origin.ATL 0.567194 0.567194 0.004794
19 Origin.PIT 0.558928 0.558928 0.004724
20 Distance 0.555028 0.555028 0.004691
21 Year.1998 0.552688 0.552688 0.004671
22 Dest.CHO 0.549932 0.549932 0.004648
23 FlightNum 0.549580 0.549580 0.004645
24 Origin.IAH 0.549099 0.549099 0.004641
25 Dest.PHX 0.543288 0.543288 0.004591
26 UniqueCarrier.HP 0.537485 0.537485 0.004542
27 Origin.PHL 0.533020 0.533020 0.004505
28 Dest.HTS 0.524996 0.524996 0.004437
29 UniqueCarrier.UA 0.524380 0.524380 0.004432
... ... ... ... ...
285 Dest.GSP 0.256756 0.256756 0.002170
286 DayOfWeek.4 0.256518 0.256518 0.002168
287 Origin.CRP 0.255690 0.255690 0.002161
288 Dest.GRR 0.254057 0.254057 0.002147
289 Dest.HPN 0.252856 0.252856 0.002137
290 Origin.LAN 0.251714 0.251714 0.002127
291 Origin.TPA 0.249546 0.249546 0.002109
292 Origin.CLT 0.248593 0.248593 0.002101
293 Origin.CVG 0.248328 0.248328 0.002099
294 Origin.PVD 0.243274 0.243274 0.002056
295 Dest.RSW 0.238737 0.238737 0.002018
296 Dest.MAF 0.237673 0.237673 0.002009
297 Dest.CAK 0.231845 0.231845 0.001959
298 Origin.BDL 0.231663 0.231663 0.001958
299 Origin.SNA 0.228553 0.228553 0.001932
300 Origin.ONT 0.225977 0.225977 0.001910
301 Dest.SAN 0.224082 0.224082 0.001894
302 Origin.STX 0.223197 0.223197 0.001886
303 Dest.ANC 0.215848 0.215848 0.001824
304 Origin.TYS 0.214950 0.214950 0.001817
305 Origin.ELP 0.214924 0.214924 0.001816
306 Origin.SRQ 0.203653 0.203653 0.001721
307 Dest.MKE 0.194508 0.194508 0.001644
308 Origin.EYW 0.179053 0.179053 0.001513
309 Dest.missing(NA) 0.000000 0.000000 0.000000
310 Origin.missing(NA) 0.000000 0.000000 0.000000
311 Year.missing(NA) 0.000000 0.000000 0.000000
312 UniqueCarrier.missing(NA) 0.000000 0.000000 0.000000
313 DayOfWeek.missing(NA) 0.000000 0.000000 0.000000
314 Month.missing(NA) 0.000000 0.000000 0.000000

315 rows × 4 columns


In [34]:
# Plot Deep Learning Feature Importances
all_importances = data_dl.varimp(use_pandas=True)
importances = all_importances[1:10]   # rows 1-9 of the importance table
feature_labels = list(importances['variable'])

plt.figure(figsize=(20, 6))
h = plt.bar(range(len(feature_labels)), importances['relative_importance'], width=0.6, color='aqua')
plt.title("Deep Learning Feature Importances", fontsize=20)
xticks_pos = [0.65 * patch.get_width() + patch.get_xy()[0] for patch in h]
plt.xticks(xticks_pos, feature_labels, fontsize=13, ha='right')


Out[34]:
[Bar chart: "Deep Learning Feature Importances", relative importance per feature]

Can H2O Handle New Categorical Levels in a Test Set?


Yes! Unlike many machine learning implementations, H2O-3's algorithms can still make predictions when a test set contains categorical levels that were not present in the training set, because each algorithm has an explicit strategy for handling new levels. So, the next question becomes:

How does each algorithm handle unseen categorical levels in a test set?

Skip to the algorithm you're using to see how it predicts on a categorical level not seen during training:


GLM :

GLM will predict 'Double.NaN' for each row containing a new categorical level, indicating that no prediction was made.

After running the cells to load, clean, and split the data, you can play with a GLM here.
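To see this concretely, here is a minimal sketch. The toy frames are invented for illustration (they are not the airlines data), and H2OFrame construction details can vary slightly between h2o-py releases:

# Hypothetical toy frames to demonstrate the behavior.
train = h2o.H2OFrame({"Origin": ["SFO", "ORD", "SFO", "ORD"],
                      "Distance": [500, 800, 550, 790],
                      "IsDepDelayed": ["YES", "NO", "NO", "YES"]})
test = h2o.H2OFrame({"Origin": ["XYZ"],   # a level never seen during training
                     "Distance": [600]})
train["IsDepDelayed"] = train["IsDepDelayed"].asfactor()

glm_toy = H2OGeneralizedLinearEstimator(family="binomial")
glm_toy.train(x=["Origin", "Distance"], y="IsDepDelayed", training_frame=train)
glm_toy.predict(test)   # expect a NaN prediction for the unseen-level row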

DRF & GBM :

When a split uses a categorical with only a small number of observed levels, an unseen factor can be routed either left or right; when the number of levels is large, unseen factors always go left.

After running the cells to load, clean, and split the data, you can play with a GBM here or a DRF here.
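Using the same hypothetical toy frames from the GLM sketch above, a tree model returns a real prediction instead of NaN, since the unseen factor is simply routed down one side of each split:

# Reusing the toy 'train' and 'test' frames from the GLM sketch above.
gbm_toy = H2OGradientBoostingEstimator(ntrees=5, max_depth=3)
gbm_toy.train(x=["Origin", "Distance"], y="IsDepDelayed", training_frame=train)
gbm_toy.predict(test)   # a concrete prediction, not NaN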

Deep Learning :

For an unseen categorical level in the test set, Deep Learning creates an extra input neuron; since that neuron was never trained, it contributes an essentially random amount to the subsequent layer.

After running the cells to load, clean, and split the data, you can play with a Deep Learning model here.

K-Means :

An unseen categorical level in a row does not contribute to that row's prediction: it adds nothing to the distance comparison against each cluster center, and therefore plays no part in deciding which cluster the row belongs to.

Naive Bayes :

If the Laplace smoothing parameter is disabled ('laplace = 0'), then Naive Bayes will predict a probability of 0 for any row in the test set that contains a previously unseen categorical level. However, if Laplace smoothing is used (e.g. 'laplace = 1'), then the model can make predictions for rows that include previously unseen categorical levels.

Laplace smoothing adjusts the maximum-likelihood estimates by adding 1 to the numerator and $k$ to the denominator, so that levels never observed during training still receive a nonzero probability:

$$\phi_{j|y=1} = \frac{\sum_{i=1}^{m} 1\left(x_j^{(i)} = 1 \wedge y^{(i)} = 1\right) + 1}{\sum_{i=1}^{m} 1\left(y^{(i)} = 1\right) + k}$$

$$\phi_{j|y=0} = \frac{\sum_{i=1}^{m} 1\left(x_j^{(i)} = 1 \wedge y^{(i)} = 0\right) + 1}{\sum_{i=1}^{m} 1\left(y^{(i)} = 0\right) + k}$$

(where $x^{(i)}$ denotes the features, $y^{(i)}$ the response, $m$ the number of training rows, and $k$ the number of levels of the categorical feature; adding $k$ to the denominator balances the 1 added to each numerator, so the smoothed estimates still sum to one)
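A back-of-the-envelope sketch with hypothetical counts shows why the smoothed estimate keeps predictions alive for an unseen level:

# Hypothetical counts for one categorical feature with k levels.
count_level_and_y1 = 0   # the unseen level never co-occurs with y = 1 in training
count_y1 = 50            # training rows where y = 1
k = 132                  # number of levels of this feature

phi_unsmoothed = count_level_and_y1 / float(count_y1)          # 0.0: the joint probability collapses to zero
phi_laplace = (count_level_and_y1 + 1) / float(count_y1 + k)   # ~0.0055: small but nonzero
print(phi_unsmoothed, phi_laplace)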

Laplace smoothing should be used with care; it is generally intended to allow predictions for rare events. As prediction data drifts further from the training data, new models should be trained when possible to account for a broader set of possible feature values.

After running the cells to load, clean, and split the data, you can play with a Naive Bayes model here.

PCA :

New categorical levels in the test data that were not present in the training data are skipped in the row product-sum (see the numpy sketch under PCA in the missing-values section below).

How Does H2O Handle Missing Values during Training & Testing?

Skip to the algorithm you're using to see how it trains or predicts with missing values:

(Note: NA values in the training set are not necessarily handled the same way as NA values in the test set)


GLM :

How does the algorithm handle missing values during training?

Depending on the selected missing-value handling policy, missing values are either imputed with the column mean or the whole row is skipped. The default behavior is mean imputation. Note that categorical variables are imputed by adding an extra "missing" level.

Optionally, GLM can skip all rows with any missing values.

How does the algorithm handle missing values during testing?

Same as during training. If missing-value handling is set to skip and we are generating predictions, skipped rows will get an NA (missing) prediction.
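As a sketch, the policy is chosen with the missing_values_handling constructor argument; this parameter name is taken from newer h2o-py releases and may differ in older versions:

# Mean imputation (the default): missing numerics get the column mean,
# missing categoricals get an extra "missing" level.
glm_impute = H2OGeneralizedLinearEstimator(family="binomial",
                                           missing_values_handling="MeanImputation")

# Skip: rows with any missing value are dropped during training;
# at scoring time, such rows get an NA prediction.
glm_skip = H2OGeneralizedLinearEstimator(family="binomial",
                                         missing_values_handling="Skip")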

DRF & GBM :

How does the algorithm handle missing values during training and testing?

Missing values always go right at every split decision.

Deep Learning :

How does the algorithm handle missing values during training?

Missing values in the training set will be mean-imputed or the whole row can be skipped, depending on how the following parameter is set: missing_values_handling = "MeanImputation" or "Skip".

How does the algorithm handle missing values during testing?

Missing values in the test set will be mean-imputed (with the mean of the training data) during scoring.

K-Means :

How does the algorithm handle missing values during training?

Missing values are automatically imputed by the column mean. K-Means also handles missing values by assuming that a missing feature's distance contribution is equal to the average of the contributions of all other features.

How does the algorithm handle missing values during testing?

Missing values are automatically imputed by the column mean of the training data.
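A standalone numpy sketch of the distance rule described above (hypothetical values): the missing feature's squared-distance term is taken to be the average of the observed terms.

import numpy as np

row    = np.array([1.0, np.nan, 3.0, 4.0])   # one feature missing
center = np.array([0.5, 2.0,    2.0, 4.5])   # a cluster center

terms = (row - center) ** 2                  # per-feature contributions; nan where missing
observed = terms[~np.isnan(terms)]
# Assume each missing contribution equals the mean of the observed contributions:
dist_sq = observed.sum() + np.isnan(terms).sum() * observed.mean()
print(dist_sq)                               # 1.5 observed + 0.5 assumed = 2.0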

Naive Bayes :

How does the algorithm handle missing values during training?

All rows with one or more missing values (either in the predictors or the response) will be skipped during model building.

How does the algorithm handle missing values during testing?

If a predictor is missing, it will be skipped when taking the product of conditional probabilities in calculating the joint probability conditional on the response.
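In sketch form (hypothetical probabilities), the missing predictor's factor simply drops out of the product:

# Hypothetical conditional probabilities P(x_j | y) for three predictors.
conditionals = [0.30, None, 0.12]   # the second predictor is missing in this row
prior_y = 0.40

joint = prior_y
for p in conditionals:
    if p is not None:               # a missing predictor is skipped
        joint *= p
print(joint)                        # 0.40 * 0.30 * 0.12 = 0.0144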

PCA :

How does the algorithm handle missing values during training?

For the GramSVD and Power methods, all rows containing missing values are ignored during training. For the GLRM method, missing values are excluded from the sum over the loss function in the objective. For more information, refer to Section 4 ("Generalized Loss Functions"), Equation (13), of "Generalized Low Rank Models" by Udell et al.

How does the algorithm handle missing values during testing?

During scoring, the test data is right-multiplied by the eigenvector matrix produced by PCA. Missing categorical values are skipped in the row product-sum, while a missing numeric value propagates NAs across that row of the resulting projection matrix.
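A numpy sketch of both behaviors, using a hypothetical eigenvector matrix and assuming one-hot expansion of the categorical:

import numpy as np

# Hypothetical eigenvector matrix: 2 one-hot columns for a categorical
# plus 1 numeric column, projected onto 2 principal components.
V = np.array([[0.6, 0.1],
              [0.2, 0.7],
              [0.5, 0.4]])

# Missing (or unseen) categorical: its one-hot columns are all zero,
# so they simply add nothing to the product-sum.
x_missing_cat = np.array([0.0, 0.0, 1.3])
print(x_missing_cat.dot(V))   # a valid projection

# Missing numeric: the NaN propagates through the product-sum.
x_missing_num = np.array([1.0, 0.0, np.nan])
print(x_missing_num.dot(V))   # [nan, nan]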


In [ ]: