In [1]:
# Install Sympy to get the expressions at the final model
! pip install sympy


Requirement already satisfied: sympy in /usr/local/lib/python2.7/site-packages
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python2.7/site-packages (from sympy)

In [2]:
# Imports
import statsmodels.datasets as datasets
import sklearn.metrics as metrics
import numpy
import matplotlib.pyplot as plt
from numpy import log
from pyearth import Earth as earth
from sklearn.ensemble import RandomForestRegressor
from pyearth import Earth
from pyearth import export

%matplotlib inline

In [3]:
# Let's get Boston dataset to deal with some continuous problem
boston = datasets.get_rdataset("Boston", "MASS").data

In [4]:
# Some visualization of the dataset
boston[0:5]


Out[4]:
crim zn indus chas nox rm age dis rad tax ptratio black lstat medv
0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 24.0
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 21.6
2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7
3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 33.4
4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33 36.2

In [5]:
# Grab the INDEPENDENT variables to put inside the model
x = boston.iloc[:, 0:boston.shape[1] - 1]

In [6]:
# Some data to play 
x[0:5]


Out[6]:
crim zn indus chas nox rm age dis rad tax ptratio black lstat
0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14
2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03
3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94
4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33

In [7]:
# Get all columns of dataset
xlabel = list(x.columns)
xlabel


Out[7]:
['crim',
 'zn',
 'indus',
 'chas',
 'nox',
 'rm',
 'age',
 'dis',
 'rad',
 'tax',
 'ptratio',
 'black',
 'lstat']

In [8]:
# Get only the DEPENDENT variable
y = boston.iloc[:, boston.shape[1] - 1]
y[:5]


Out[8]:
0    24.0
1    21.6
2    34.7
3    33.4
4    36.2
Name: medv, dtype: float64

In [9]:
# Put some parameters to model
model = earth(allow_linear=None
            ,allow_missing=False
            ,check_every=None
            ,enable_pruning=True
            ,endspan=None
            ,endspan_alpha=0.05
            ,fast_K=None
            ,fast_h=None
            ,feature_importance_type=None
            ,max_degree=None
            ,max_terms=None
            ,min_search_points=None
            ,minspan=None
            ,minspan_alpha=0.05
            ,penalty=3
            ,smooth=None
            ,thresh=None
            ,use_fast=None
            ,verbose=0
            ,zero_tol=None)

In [10]:
# Fit the model
model.fit(x, log(y), xlabels = xlabel)


Out[10]:
Earth(allow_linear=None, allow_missing=False, check_every=None,
   enable_pruning=True, endspan=None, endspan_alpha=0.05, fast_K=None,
   fast_h=None, feature_importance_type=None, max_degree=None,
   max_terms=None, min_search_points=None, minspan=None,
   minspan_alpha=0.05, penalty=3, smooth=None, thresh=None, use_fast=None,
   verbose=0, zero_tol=None)

In [11]:
# GRab some data from the model
print model.summary()


Earth Model
---------------------------------------
Basis Function   Pruned  Coefficient   
---------------------------------------
(Intercept)      No      1.92889       
h(lstat-5.68)    No      0.0291543     
h(5.68-lstat)    Yes     None          
h(rm-6.383)      No      0.27093       
h(6.383-rm)      Yes     None          
h(crim-24.8017)  No      -0.0274698    
h(24.8017-crim)  No      0.0406624     
h(dis-1.5106)    No      -0.0353598    
h(1.5106-dis)    No      1.41847       
ptratio          No      -0.0292831    
nox              No      -0.586277     
rad              No      0.0203696     
tax              No      -0.000525082  
h(black-179.36)  No      -0.000554406  
h(179.36-black)  No      -0.00144926   
h(lstat-30.59)   Yes     None          
h(30.59-lstat)   No      0.0564323     
h(rm-7.82)       No      -0.313019     
h(7.82-rm)       Yes     None          
indus            Yes     None          
h(tax-666)       Yes     None          
h(666-tax)       Yes     None          
h(crim-9.96654)  No      0.0218642     
h(9.96654-crim)  Yes     None          
chas             Yes     None          
---------------------------------------
MSE: 0.0241, GCV: 0.0282, RSQ: 0.8558, GRSQ: 0.8317

In [12]:
# Some metrics about the model (R**2)
r2 = metrics.r2_score(log(y), model.predict(x))
mean_absolute_error = metrics.mean_absolute_error(log(y), model.predict(x))
mean_squared_error = metrics.mean_squared_error(log(y), model.predict(x))

print 'R**2: ', r2
print 'Mean Absolute Error: ', mean_absolute_error
print 'Mean Squared Error: ', mean_squared_error
print 'Root Mean Squared Error', numpy.sqrt(mean_squared_error) #Fix it


R**2:  0.855759400541
Mean Absolute Error:  0.11375424877
Mean Squared Error:  0.0240524027952
Root Mean Squared Error 0.155088370922

In [13]:
#See the traces of the model
print(model.trace())


Forward Pass
---------------------------------------------------------------
iter  parent  var  knot  mse       terms  gcv    rsq    grsq   
---------------------------------------------------------------
0     -       -    -     0.166752  1      0.167  0.000  0.000  
1     0       12   183   0.052668  3      0.054  0.684  0.678  
2     0       5    56    0.044880  5      0.047  0.731  0.720  
3     0       0    403   0.036296  7      0.039  0.782  0.769  
4     0       7    367   0.031285  9      0.034  0.812  0.797  
5     0       10   -1    0.029718  10     0.033  0.822  0.805  
6     0       4    -1    0.028164  11     0.031  0.831  0.813  
7     0       8    -1    0.026881  12     0.030  0.839  0.820  
8     0       9    -1    0.025696  13     0.029  0.846  0.826  
9     0       11   409   0.025057  15     0.029  0.850  0.827  
10    0       12   398   0.024554  17     0.029  0.853  0.826  
11    0       5    280   0.024169  19     0.029  0.855  0.825  
12    0       2    -1    0.023990  20     0.029  0.856  0.825  
13    0       9    402   0.023761  22     0.030  0.858  0.823  
14    0       0    443   0.023468  24     0.030  0.859  0.821  
15    0       3    -1    0.023359  25     0.030  0.860  0.820  
---------------------------------------------------------------
Stopping Condition 2: Improvement below threshold

Pruning Pass
--------------------------------------------
iter  bf  terms  mse   gcv    rsq    grsq   
--------------------------------------------
0     -   25     0.02  0.030  0.860  0.819  
1     23  24     0.02  0.030  0.860  0.821  
2     21  23     0.02  0.030  0.860  0.823  
3     18  22     0.02  0.029  0.860  0.825  
4     2   21     0.02  0.029  0.860  0.827  
5     15  20     0.02  0.029  0.860  0.829  
6     24  19     0.02  0.028  0.859  0.830  
7     4   18     0.02  0.028  0.858  0.831  
8     20  17     0.02  0.028  0.857  0.831  
9     19  16     0.02  0.028  0.856  0.832  
10    22  15     0.02  0.028  0.854  0.832  
11    5   14     0.02  0.028  0.853  0.832  
12    13  13     0.02  0.028  0.851  0.831  
13    17  12     0.03  0.028  0.849  0.831  
14    1   11     0.03  0.029  0.845  0.828  
15    14  10     0.03  0.029  0.841  0.826  
16    12  9      0.03  0.030  0.835  0.821  
17    11  8      0.03  0.031  0.825  0.813  
18    10  7      0.03  0.032  0.818  0.806  
19    7   6      0.03  0.033  0.814  0.804  
20    9   5      0.03  0.034  0.805  0.797  
21    8   4      0.04  0.039  0.774  0.767  
22    6   3      0.04  0.046  0.731  0.725  
23    3   2      0.06  0.057  0.660  0.657  
24    16  1      0.17  0.167  0.000  0.000  
--------------------------------------------
Selected iteration: 9


In [14]:
# Summary of the model
print(model.summary())


Earth Model
---------------------------------------
Basis Function   Pruned  Coefficient   
---------------------------------------
(Intercept)      No      1.92889       
h(lstat-5.68)    No      0.0291543     
h(5.68-lstat)    Yes     None          
h(rm-6.383)      No      0.27093       
h(6.383-rm)      Yes     None          
h(crim-24.8017)  No      -0.0274698    
h(24.8017-crim)  No      0.0406624     
h(dis-1.5106)    No      -0.0353598    
h(1.5106-dis)    No      1.41847       
ptratio          No      -0.0292831    
nox              No      -0.586277     
rad              No      0.0203696     
tax              No      -0.000525082  
h(black-179.36)  No      -0.000554406  
h(179.36-black)  No      -0.00144926   
h(lstat-30.59)   Yes     None          
h(30.59-lstat)   No      0.0564323     
h(rm-7.82)       No      -0.313019     
h(7.82-rm)       Yes     None          
indus            Yes     None          
h(tax-666)       Yes     None          
h(666-tax)       Yes     None          
h(crim-9.96654)  No      0.0218642     
h(9.96654-crim)  Yes     None          
chas             Yes     None          
---------------------------------------
MSE: 0.0241, GCV: 0.0282, RSQ: 0.8558, GRSQ: 0.8317

In [15]:
# See predictions
y_hat = model.predict(x)
y_hat


Out[15]:
array([ 3.32406087,  3.13255523,  3.52929814,  3.51118389,  3.41392939,
        3.22957942,  3.07854321,  2.87843681,  2.58270934,  2.91628264,
        2.82812743,  3.03016581,  2.99555127,  3.00339681,  2.96655025,
        3.00586111,  3.0449344 ,  2.84379035,  2.99483962,  2.95262371,
        2.6734715 ,  2.86942517,  2.71939587,  2.69480206,  2.79149452,
        2.83055139,  2.83501332,  2.80326325,  2.91798155,  2.99141717,
        2.62899892,  2.87354113,  2.55846911,  2.76061944,  2.74979543,
        3.14594336,  3.10749957,  3.14982392,  3.11386322,  3.3606862 ,
        3.60873003,  3.37901793,  3.22377793,  3.17244031,  3.11918581,
        3.11697805,  3.00881418,  2.86178772,  2.53690228,  2.91670651,
        3.02469226,  3.1370755 ,  3.29457934,  3.16246715,  2.72436909,
        3.47488032,  3.07717649,  3.48242662,  3.13594507,  3.10141706,
        2.98296052,  2.97069099,  3.1784547 ,  3.1591606 ,  3.24080506,
        3.29740904,  3.11570181,  3.08525679,  2.94559994,  3.06319513,
        3.19047936,  3.09596718,  3.21454786,  3.15818056,  3.18253902,
        3.12056374,  3.05888268,  3.09711526,  3.00954679,  3.1095437 ,
        3.32628747,  3.23357258,  3.18289441,  3.16485934,  3.12545407,
        3.28939491,  3.04835376,  3.19262022,  3.42288604,  3.43255552,
        3.18216058,  3.19510227,  3.22869171,  3.26570696,  3.14713643,
        3.31587989,  3.10389926,  3.71704842,  3.76709762,  3.52332917,
        3.14735696,  3.20789664,  2.98817267,  2.94433192,  2.98891054,
        2.88661624,  2.82072973,  2.95457462,  3.01383823,  2.89245353,
        2.95886221,  3.19906484,  2.95498654,  2.9190012 ,  3.11820325,
        2.98917128,  3.05599316,  3.10138536,  3.00443956,  3.01259984,
        3.00262535,  3.01374175,  2.91671119,  2.7196907 ,  2.92816964,
        2.99820856,  2.67231966,  2.75621563,  2.81211415,  2.69011939,
        2.8851509 ,  2.83300191,  2.8927955 ,  2.79005719,  2.77577875,
        2.73752727,  2.76291338,  2.83931858,  2.64820934,  2.71154226,
        2.57326914,  2.55453332,  2.72803242,  2.57942901,  2.64819799,
        2.71254876,  2.89054005,  2.49826315,  2.48073671,  2.65473751,
        2.88416025,  2.93127648,  2.9763643 ,  2.87930306,  2.89026841,
        2.75302994,  2.76795309,  3.49665579,  3.26764798,  3.11049711,
        3.30793156,  3.78682188,  3.8349552 ,  3.74662021,  3.04259707,
        3.1556046 ,  3.73869732,  3.1396592 ,  3.12069581,  3.08976472,
        3.06544088,  3.05627315,  3.094015  ,  3.25800288,  3.23228516,
        3.39098337,  3.18988208,  3.30181363,  3.42998972,  3.52382592,
        3.65263309,  3.23084716,  3.58838406,  3.37237563,  3.09618141,
        3.11158012,  3.77358898,  3.41281752,  3.41875822,  3.53625116,
        3.43194387,  3.39185146,  3.61235137,  3.39603365,  3.38915838,
        3.89555431,  3.61322254,  3.41758491,  3.49647207,  3.34355636,
        3.40065616,  3.19105768,  3.71587953,  3.87872179,  3.92543249,
        3.09396889,  3.07409413,  2.8828028 ,  2.98885684,  2.75095206,
        2.92228199,  2.73745747,  2.96358807,  3.14059465,  2.61744062,
        3.1314661 ,  3.10694112,  3.26940124,  2.98960382,  3.1768527 ,
        3.40200123,  2.91702932,  3.35191307,  3.34150891,  3.82153215,
        3.76764857,  3.87267119,  3.54873087,  3.8042937 ,  3.49868017,
        3.16844449,  3.63858282,  3.86659992,  3.80723568,  3.37219345,
        3.19766859,  3.26965207,  3.62497955,  3.29906146,  3.30254441,
        3.26184507,  3.09318659,  3.1128997 ,  3.29229078,  2.97107342,
        2.79987141,  3.04337848,  3.02998248,  3.0709128 ,  3.21724923,
        3.18850781,  3.30682646,  3.3977938 ,  3.59491333,  3.07078509,
        2.99692746,  3.72869902,  3.77831124,  3.59254656,  3.47518073,
        3.50027994,  3.65144255,  3.74136743,  3.4778328 ,  3.55048946,
        3.25031386,  3.30134706,  3.74030831,  3.83819654,  3.04670808,
        3.03936504,  3.21521551,  3.24720682,  3.57304328,  3.52303081,
        3.57175989,  3.50670008,  3.47985354,  3.28256886,  3.57788487,
        3.90390653,  3.6040079 ,  3.88516966,  3.87824084,  3.33836867,
        3.15445785,  2.92177952,  3.22061396,  3.20777375,  3.2190613 ,
        3.50533049,  3.56832543,  3.36538742,  3.20387741,  3.15425841,
        3.32873199,  3.26855011,  2.97456621,  3.29842701,  3.48819488,
        3.36822823,  3.27733578,  3.27863902,  3.53433806,  3.54074051,
        3.34145447,  3.633176  ,  3.44765305,  3.33528661,  3.09946362,
        2.97967857,  3.20634742,  3.06475457,  3.15509462,  3.14996356,
        3.03173506,  2.84407144,  2.90762836,  3.0720519 ,  2.98725318,
        3.16691261,  3.16369546,  3.12777095,  3.02350036,  3.17994851,
        3.21343031,  3.14742215,  2.96883565,  3.0709594 ,  3.1466652 ,
        3.08019442,  2.96566533,  3.1093904 ,  3.14025154,  3.11178293,
        3.08989273,  3.06795072,  3.03496053,  3.11842201,  3.08355789,
        3.09556703,  3.42913533,  3.04816976,  3.22451296,  3.35285375,
        2.9318643 ,  2.88390888,  3.11481192,  3.22500381,  3.12875047,
        2.99939244,  3.06110958,  2.9280683 ,  3.27886533,  2.86277338,
        2.93479209,  2.61730213,  2.92630916,  2.90694881,  2.92288244,
        3.05730093,  2.92845893,  3.03409628,  2.91300005,  3.58841019,
        3.14444343,  3.00821734,  2.77056259,  3.60209039,  3.59124299,
        3.89918978,  3.42124355,  3.52268192,  2.87134579,  2.87848622,
        3.10613482,  2.71835723,  2.88068342,  2.39963349,  2.56216813,
        2.31034474,  2.48394909,  2.49490632,  2.51896489,  2.24839494,
        2.27260086,  2.1283716 ,  2.07811284,  2.20428912,  2.60679496,
        2.75154   ,  2.77476693,  2.36947616,  2.74355896,  2.5941588 ,
        2.71632532,  2.75174232,  2.66122925,  1.95406504,  2.34508773,
        2.10030051,  2.47674549,  2.58727988,  2.28858502,  2.0279942 ,
        2.09357722,  2.78116651,  3.1209413 ,  2.67878716,  2.85784927,
        2.47042668,  2.4993105 ,  2.18019379,  2.42189721,  2.0524969 ,
        2.0792353 ,  2.39104536,  2.1502086 ,  1.88140846,  2.48422856,
        2.70669582,  2.82010195,  2.77893304,  2.48937219,  2.60351566,
        2.20333543,  2.58575179,  2.25019275,  2.62625148,  2.42290212,
        2.71865332,  2.71857135,  2.98213306,  2.82944574,  2.59515795,
        2.49358733,  2.41412859,  2.12643521,  2.23141336,  2.47444289,
        2.24614976,  2.5560647 ,  2.79149577,  2.58852828,  2.45682063,
        2.34538786,  2.7741725 ,  2.61957101,  2.60663101,  2.71106867,
        2.68973563,  2.86908777,  2.80523288,  2.96704493,  2.53883243,
        2.75950645,  2.63669857,  2.53241617,  2.77326672,  2.78686509,
        2.98976787,  2.92296088,  2.81158171,  2.98096548,  2.81033182,
        3.00663967,  2.76420945,  2.82128656,  2.54598472,  2.67230107,
        2.91001747,  3.04261027,  3.01581307,  3.1960213 ,  2.7523871 ,
        2.6929528 ,  2.85713079,  2.39072715,  2.64922802,  2.71209706,
        3.01574667,  3.20711598,  3.30633866,  3.12506225,  3.05893175,
        3.06217283,  2.87360246,  3.02109408,  2.63121546,  2.49929389,
        2.35438505,  2.63386933,  2.75007986,  3.02369167,  2.97627175,
        2.85823698,  2.75517404,  2.94474617,  2.9952239 ,  2.939202  ,
        2.95421083,  3.06265089,  3.02659052,  3.28586778,  3.20457732,
        3.05154626])

In [16]:
# Now to see the variable importances we'll list first a set of metrics 
criteria = ('rss', 'gcv', 'nb_subsets')

In [17]:
# Put some parameters inside the model
# Second model considering variable importances

model2 = Earth(max_degree=3
              ,max_terms=10
              ,minspan_alpha=.5
              ,feature_importance_type=criteria #This is the variable that made de selections about variable importances
              ,verbose=True)

In [18]:
# Fit the model
model2.fit(x, log(y), xlabels = xlabel)


Beginning forward pass
---------------------------------------------------------------
iter  parent  var  knot  mse       terms  gcv    rsq    grsq   
---------------------------------------------------------------
0     -       -    -     0.166752  1      0.167  0.000  0.000  
1     0       12   89    0.052664  3      0.054  0.684  0.678  
2     0       5    84    0.044882  5      0.047  0.731  0.720  
3     1       0    474   0.036055  7      0.038  0.784  0.770  
4     0       7    379   0.031645  9      0.034  0.810  0.794  
5     4       12   423   0.028962  11     0.032  0.826  0.808  
---------------------------------------------------------------
Stopping Condition 0: Reached maximum number of terms
Beginning pruning pass
--------------------------------------------
iter  bf  terms  mse   gcv    rsq    grsq   
--------------------------------------------
0     -   11     0.03  0.032  0.826  0.808  
1     10  10     0.03  0.032  0.826  0.810  
2     4   9      0.03  0.032  0.826  0.811  
3     7   8      0.03  0.032  0.823  0.811  
4     2   7      0.03  0.033  0.814  0.802  
5     5   6      0.03  0.035  0.800  0.789  
6     9   5      0.04  0.037  0.789  0.780  
7     8   4      0.04  0.042  0.755  0.747  
8     6   3      0.05  0.049  0.715  0.709  
9     3   2      0.06  0.065  0.617  0.613  
10    1   1      0.17  0.167  0.000  0.000  
--------------------------------------------
Selected iteration: 2
Out[18]:
Earth(allow_linear=None, allow_missing=False, check_every=None,
   enable_pruning=True, endspan=None, endspan_alpha=None, fast_K=None,
   fast_h=None, feature_importance_type=('rss', 'gcv', 'nb_subsets'),
   max_degree=3, max_terms=10, min_search_points=None, minspan=None,
   minspan_alpha=0.5, penalty=None, smooth=None, thresh=None,
   use_fast=None, verbose=True, zero_tol=None)

In [19]:
# We'll use RF estimator to get the variables importantes
rf = RandomForestRegressor()

In [20]:
# Let's fit the model with default paarmeters
rf.fit(x, log(y))


Out[20]:
RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
           max_features='auto', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
           oob_score=False, random_state=None, verbose=0, warm_start=False)

In [21]:
# Print model trace
print(model2.trace())


Forward Pass
---------------------------------------------------------------
iter  parent  var  knot  mse       terms  gcv    rsq    grsq   
---------------------------------------------------------------
0     -       -    -     0.166752  1      0.167  0.000  0.000  
1     0       12   89    0.052664  3      0.054  0.684  0.678  
2     0       5    84    0.044882  5      0.047  0.731  0.720  
3     1       0    474   0.036055  7      0.038  0.784  0.770  
4     0       7    379   0.031645  9      0.034  0.810  0.794  
5     4       12   423   0.028962  11     0.032  0.826  0.808  
---------------------------------------------------------------
Stopping Condition 0: Reached maximum number of terms

Pruning Pass
--------------------------------------------
iter  bf  terms  mse   gcv    rsq    grsq   
--------------------------------------------
0     -   11     0.03  0.032  0.826  0.808  
1     10  10     0.03  0.032  0.826  0.810  
2     4   9      0.03  0.032  0.826  0.811  
3     7   8      0.03  0.032  0.823  0.811  
4     2   7      0.03  0.033  0.814  0.802  
5     5   6      0.03  0.035  0.800  0.789  
6     9   5      0.04  0.037  0.789  0.780  
7     8   4      0.04  0.042  0.755  0.747  
8     6   3      0.05  0.049  0.715  0.709  
9     3   2      0.06  0.065  0.617  0.613  
10    1   1      0.17  0.167  0.000  0.000  
--------------------------------------------
Selected iteration: 2


In [22]:
# Print the summary of the model
print(model2.summary())


Earth Model
----------------------------------------------------
Basis Function                Pruned  Coefficient   
----------------------------------------------------
(Intercept)                   No      3.24538       
h(lstat-5.7)                  No      -0.0475761    
h(5.7-lstat)                  No      0.0799476     
h(rm-6.389)                   No      0.276175      
h(6.389-rm)                   Yes     None          
h(crim-8.05579)*h(lstat-5.7)  No      -0.000540323  
h(8.05579-crim)*h(lstat-5.7)  No      0.00245139    
h(dis-1.3861)                 No      -0.0119719    
h(1.3861-dis)                 No      2.06961       
h(lstat-23.29)*h(6.389-rm)    No      0.0264174     
h(23.29-lstat)*h(6.389-rm)    Yes     None          
----------------------------------------------------
MSE: 0.0290, GCV: 0.0316, RSQ: 0.8260, GRSQ: 0.8114

In [23]:
# Let's see the most important variables
print(model2.summary_feature_importances(sort_by='gcv'))


            nb_subsets    gcv    rss
lstat       0.42          0.78   0.78   
rm          0.25          0.12   0.12   
crim        0.17          0.06   0.06   
dis         0.17          0.04   0.04   
black       0.00          0.00   0.00   
ptratio     0.00          0.00   0.00   
tax         0.00          0.00   0.00   
rad         0.00          0.00   0.00   
age         0.00          0.00   0.00   
nox         0.00          0.00   0.00   
chas        0.00          0.00   0.00   
indus       0.00          0.00   0.00   
zn          0.00          0.00   0.00   


In [24]:
importances = model2.feature_importances_
importances['random_forest'] = rf.feature_importances_
random_values = importances['random_forest']
gcv_values = importances['gcv']

In [25]:
# Coeficients of the model
model2.coef_


Out[25]:
array([[  3.24537739e+00,  -4.75761009e-02,   7.99475897e-02,
          2.76174514e-01,  -5.40322617e-04,   2.45139029e-03,
         -1.19719092e-02,   2.06961106e+00,   2.64174196e-02]])

In [26]:
inputs = numpy.arange(x.shape[1])
inputs


Out[26]:
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

In [27]:
fig = plt.figure(figsize=(20,5))

plt.subplot(1,2,1)
plt.bar(inputs, gcv_values)
plt.title("Earth")
plt.xticks(inputs)
plt.xlabel("Variable")
plt.ylabel("Importance")

plt.subplot(1,2,2)
plt.bar(inputs, random_values)
plt.title("RandomForestRegressor")
plt.xticks(inputs)
plt.xlabel("Variable")
plt.ylabel("Importance")


Out[27]:
<matplotlib.text.Text at 0x110247c90>

In [28]:
# List of variables
xlabel


Out[28]:
['crim',
 'zn',
 'indus',
 'chas',
 'nox',
 'rm',
 'age',
 'dis',
 'rad',
 'tax',
 'ptratio',
 'black',
 'lstat']

In [29]:
# See the EXACT data that needed to put inside the model
x[1:2]


Out[29]:
crim zn indus chas nox rm age dis rad tax ptratio black lstat
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.9 9.14

In [30]:
# Return sympy expression 
print('Resulting sympy expression:')
print(export.export_sympy(model))


Resulting sympy expression:
-0.586276615586754*nox - 0.0292830837707757*ptratio + 0.0203696402408519*rad - 0.000525081689794919*tax - 0.00144925907154959*Max(0, -black + 179.36) - 0.00055440571664812*Max(0, black - 179.36) + 0.0406624012969403*Max(0, -crim + 24.8017) - 0.0274698195965718*Max(0, crim - 24.8017) + 0.0218641874217271*Max(0, crim - 9.96654) + 1.41847368743171*Max(0, -dis + 1.5106) - 0.0353597811251357*Max(0, dis - 1.5106) + 0.0564322709192822*Max(0, -lstat + 30.59) + 0.0291543402681494*Max(0, lstat - 5.68) - 0.313018656956002*Max(0, rm - 7.82) + 0.270930341227841*Max(0, rm - 6.383) + 1.92888707548532