Building a Classification Decision Tree with [R] --in progress--

This is an introduction to Decision Trees using the [R] language. Using the iris dataset we'll build a tree that identifies the Species of a flower from the width and length of its Sepal and Petal.

For this purpose we'll use the caret library, which in turn calls the rpart package (Recursive Partitioning and Regression Trees) to grow the classification tree.

We'll also need the rattle package to do some fancy plots for evaluating the model.
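
If caret, rattle, or rpart is missing on your system, a one-off install will fetch them (a minimal sketch, assuming a standard CRAN mirror):

install.packages(c("caret", "rattle", "rpart"))  # only needed once per R installation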

Let's start!


In [1]:
library(datasets)
library(caret)
library(rattle)


Loading required package: lattice
Loading required package: ggplot2
Loading required package: RGtk2
Rattle: A free graphical interface for data mining with R.
Version 3.5.0 Copyright (c) 2006-2015 Togaware Pty Ltd.
Type 'rattle()' to shake, rattle, and roll your data.

1. Get the dataset


In [2]:
indata <- datasets::iris

In [3]:
head(indata)


Out[3]:
  Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1          5.1         3.5          1.4         0.2  setosa
2          4.9         3.0          1.4         0.2  setosa
3          4.7         3.2          1.3         0.2  setosa
4          4.6         3.1          1.5         0.2  setosa
5          5.0         3.6          1.4         0.2  setosa
6          5.4         3.9          1.7         0.4  setosa

2. Explore the dataset

Make pair plots and density plots to explore the relations in your data. One can simply use the pairs plot from base R:


In [5]:
pairs(~ ., data = indata)


or make it a bit prettier with caret's featurePlot():


In [7]:
featurePlot(x = indata[, 1:4],
            y = indata$Species,
            plot = "pairs",
            ## Add a key at the top
            auto.key = list(columns = 3))


One can make the same plot but add an ellipse around each class:


In [8]:
featurePlot(x = indata[, 1:4],
            y = indata$Species,
            plot = "ellipse",
            ## Add a key at the top
            auto.key = list(columns = 3))


Now the density plots:


In [9]:
featurePlot(x = indata[, 1:4],
            y = indata$Species,
            plot = "density",
            ## Pass in options to xyplot() to
            ## make it prettier
            scales = list(x = list(relation = "free"),
                          y = list(relation = "free")),
            adjust = 1.5,
            pch = "|",
            layout = c(4, 1),
            auto.key = list(columns = 3))


... and box plots:


In [10]:
featurePlot(x = indata[, 1:4],
            y = indata$Species,
            plot = "box",
            ## Pass in options to bwplot()
            scales = list(y = list(relation = "free"),
                          x = list(rot = 90)),
            layout = c(4, 1),
            auto.key = list(columns = 2))


3. Set the seed


In [11]:
set.seed(1987)

4. Split the dataset into training and testing sets

To perform the split, I will sample, without replacement, a percentage of the integers from 1 to nrow(indata); let's say 60%. These 90 random integers will be the row indices of my training sample. The remaining rows form my testing sample.

Let's skip CV and validation samples for the moment...


In [12]:
train_indx <- sample(nrow(indata), floor(nrow(indata) * 0.6))  # nrow(indata) = 150
                                                               # floor(nrow(indata) * 0.6) = 90

In [13]:
train_indx


Out[13]:
 [1]  15 120  90  57 146 144  56  34  85  49 142 104 100  25  41  64  19  88
[19] 108  50  40   2  70  71  63  37   6  27  83  73  86  24  53  52  47  38
[37] 136  89  32 106  76  77  39 117  43  17 137  54 133  48  62 125 111  26
[55] 124  30 105 134   3  94  68  99   1 107  60  22 132  59  82 113   7 114
[73] 141  74  91  96  18   4 112 138   8  29 130  66   9  79 149  21  84 122

In [14]:
train_sample <- indata[train_indx, ]

In [16]:
test_sample <- indata[-train_indx, ]

For example, my test sample now holds:


In [17]:
test_sample


Out[17]:
    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
5            5.0         3.6          1.4         0.2     setosa
10           4.9         3.1          1.5         0.1     setosa
11           5.4         3.7          1.5         0.2     setosa
12           4.8         3.4          1.6         0.2     setosa
13           4.8         3.0          1.4         0.1     setosa
14           4.3         3.0          1.1         0.1     setosa
16           5.7         4.4          1.5         0.4     setosa
20           5.1         3.8          1.5         0.3     setosa
23           4.6         3.6          1.0         0.2     setosa
28           5.2         3.5          1.5         0.2     setosa
31           4.8         3.1          1.6         0.2     setosa
33           5.2         4.1          1.5         0.1     setosa
35           4.9         3.1          1.5         0.2     setosa
36           5.0         3.2          1.2         0.2     setosa
42           4.5         2.3          1.3         0.3     setosa
44           5.0         3.5          1.6         0.6     setosa
45           5.1         3.8          1.9         0.4     setosa
46           4.8         3.0          1.4         0.3     setosa
51           7.0         3.2          4.7         1.4 versicolor
55           6.5         2.8          4.6         1.5 versicolor
58           4.9         2.4          3.3         1.0 versicolor
61           5.0         2.0          3.5         1.0 versicolor
65           5.6         2.9          3.6         1.3 versicolor
67           5.6         3.0          4.5         1.5 versicolor
69           6.2         2.2          4.5         1.5 versicolor
72           6.1         2.8          4.0         1.3 versicolor
75           6.4         2.9          4.3         1.3 versicolor
78           6.7         3.0          5.0         1.7 versicolor
80           5.7         2.6          3.5         1.0 versicolor
81           5.5         2.4          3.8         1.1 versicolor
87           6.7         3.1          4.7         1.5 versicolor
92           6.1         3.0          4.6         1.4 versicolor
93           5.8         2.6          4.0         1.2 versicolor
95           5.6         2.7          4.2         1.3 versicolor
97           5.7         2.9          4.2         1.3 versicolor
98           6.2         2.9          4.3         1.3 versicolor
101          6.3         3.3          6.0         2.5  virginica
102          5.8         2.7          5.1         1.9  virginica
103          7.1         3.0          5.9         2.1  virginica
109          6.7         2.5          5.8         1.8  virginica
110          7.2         3.6          6.1         2.5  virginica
115          5.8         2.8          5.1         2.4  virginica
116          6.4         3.2          5.3         2.3  virginica
118          7.7         3.8          6.7         2.2  virginica
119          7.7         2.6          6.9         2.3  virginica
121          6.9         3.2          5.7         2.3  virginica
123          7.7         2.8          6.7         2.0  virginica
126          7.2         3.2          6.0         1.8  virginica
127          6.2         2.8          4.8         1.8  virginica
128          6.1         3.0          4.9         1.8  virginica
129          6.4         2.8          5.6         2.1  virginica
131          7.4         2.8          6.1         1.9  virginica
135          6.1         2.6          5.6         1.4  virginica
139          6.0         3.0          4.8         1.8  virginica
140          6.9         3.1          5.4         2.1  virginica
143          5.8         2.7          5.1         1.9  virginica
145          6.7         3.3          5.7         2.5  virginica
147          6.3         2.5          5.0         1.9  virginica
148          6.5         3.0          5.2         2.0  virginica
150          5.9         3.0          5.1         1.8  virginica

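As an aside, caret also provides createDataPartition(), which samples within each Species so the class proportions are preserved in both halves. A minimal sketch of the same 60/40 split done that way (the variable names here are hypothetical and not used in the rest of this post):

train_indx2 <- createDataPartition(indata$Species, p = 0.6, list = FALSE)  # stratified row indices
train_strat <- indata[train_indx2, ]   # 60% of each Species
test_strat  <- indata[-train_indx2, ]  # remaining 40%
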
5. Create the formula to train the model

Now, to train the model we'll use the train function. It needs a formula object specifying which variable is the response (Y) and which are the predictors (X's), in the form Y ~ X1 + X2 + ...

Also, the name of the model to fit has to be specified as a string. To find out which model names are included in the caret package, simply do:


In [18]:
names(getModelInfo())


Out[18]:
  1. 'ada'
  2. 'AdaBag'
  3. 'AdaBoost.M1'
  4. 'amdai'
  5. 'ANFIS'
  6. 'avNNet'
  7. 'awnb'
  8. 'awtan'
  9. 'bag'
  10. 'bagEarth'
  11. 'bagEarthGCV'
  12. 'bagFDA'
  13. 'bagFDAGCV'
  14. 'bartMachine'
  15. 'bayesglm'
  16. 'bdk'
  17. 'binda'
  18. 'blackboost'
  19. 'Boruta'
  20. 'brnn'
  21. 'BstLm'
  22. 'bstSm'
  23. 'bstTree'
  24. 'C5.0'
  25. 'C5.0Cost'
  26. 'C5.0Rules'
  27. 'C5.0Tree'
  28. 'cforest'
  29. 'chaid'
  30. 'CSimca'
  31. 'ctree'
  32. 'ctree2'
  33. 'cubist'
  34. 'DENFIS'
  35. 'dnn'
  36. 'dwdLinear'
  37. 'dwdPoly'
  38. 'dwdRadial'
  39. 'earth'
  40. 'elm'
  41. 'enet'
  42. 'enpls.fs'
  43. 'enpls'
  44. 'evtree'
  45. 'extraTrees'
  46. 'fda'
  47. 'FH.GBML'
  48. 'FIR.DM'
  49. 'foba'
  50. 'FRBCS.CHI'
  51. 'FRBCS.W'
  52. 'FS.HGD'
  53. 'gam'
  54. 'gamboost'
  55. 'gamLoess'
  56. 'gamSpline'
  57. 'gaussprLinear'
  58. 'gaussprPoly'
  59. 'gaussprRadial'
  60. 'gbm'
  61. 'gcvEarth'
  62. 'GFS.FR.MOGUL'
  63. 'GFS.GCCL'
  64. 'GFS.LT.RS'
  65. 'GFS.THRIFT'
  66. 'glm'
  67. 'glmboost'
  68. 'glmnet'
  69. 'glmStepAIC'
  70. 'gpls'
  71. 'hda'
  72. 'hdda'
  73. 'HYFIS'
  74. 'icr'
  75. 'J48'
  76. 'JRip'
  77. 'kernelpls'
  78. 'kknn'
  79. 'knn'
  80. 'krlsPoly'
  81. 'krlsRadial'
  82. 'lars'
  83. 'lars2'
  84. 'lasso'
  85. 'lda'
  86. 'lda2'
  87. 'leapBackward'
  88. 'leapForward'
  89. 'leapSeq'
  90. 'Linda'
  91. 'lm'
  92. 'lmStepAIC'
  93. 'LMT'
  94. 'loclda'
  95. 'logicBag'
  96. 'LogitBoost'
  97. 'logreg'
  98. 'lssvmLinear'
  99. 'lssvmPoly'
  100. 'lssvmRadial'
  101. 'lvq'
  102. 'M5'
  103. 'M5Rules'
  104. 'mda'
  105. 'Mlda'
  106. 'mlp'
  107. 'mlpWeightDecay'
  108. 'multinom'
  109. 'nb'
  110. 'nbDiscrete'
  111. 'nbSearch'
  112. 'neuralnet'
  113. 'nnet'
  114. 'nnls'
  115. 'nodeHarvest'
  116. 'oblique.tree'
  117. 'OneR'
  118. 'ORFlog'
  119. 'ORFpls'
  120. 'ORFridge'
  121. 'ORFsvm'
  122. 'ownn'
  123. 'pam'
  124. 'parRF'
  125. 'PART'
  126. 'partDSA'
  127. 'pcaNNet'
  128. 'pcr'
  129. 'pda'
  130. 'pda2'
  131. 'penalized'
  132. 'PenalizedLDA'
  133. 'plr'
  134. 'pls'
  135. 'plsRglm'
  136. 'polr'
  137. 'ppr'
  138. 'protoclass'
  139. 'pythonKnnReg'
  140. 'qda'
  141. 'QdaCov'
  142. 'qrf'
  143. 'qrnn'
  144. 'ranger'
  145. 'rbf'
  146. 'rbfDDA'
  147. 'rda'
  148. 'relaxo'
  149. 'rf'
  150. 'rFerns'
  151. 'RFlda'
  152. 'rfRules'
  153. 'ridge'
  154. 'rknn'
  155. 'rknnBel'
  156. 'rlm'
  157. 'rmda'
  158. 'rocc'
  159. 'rotationForest'
  160. 'rotationForestCp'
  161. 'rpart'
  162. 'rpart2'
  163. 'rpartCost'
  164. 'rqlasso'
  165. 'rqnc'
  166. 'RRF'
  167. 'RRFglobal'
  168. 'rrlda'
  169. 'RSimca'
  170. 'rvmLinear'
  171. 'rvmPoly'
  172. 'rvmRadial'
  173. 'SBC'
  174. 'sda'
  175. 'sddaLDA'
  176. 'sddaQDA'
  177. 'sdwd'
  178. 'simpls'
  179. 'SLAVE'
  180. 'slda'
  181. 'smda'
  182. 'snn'
  183. 'sparseLDA'
  184. 'spls'
  185. 'stepLDA'
  186. 'stepQDA'
  187. 'superpc'
  188. 'svmBoundrangeString'
  189. 'svmExpoString'
  190. 'svmLinear'
  191. 'svmLinear2'
  192. 'svmPoly'
  193. 'svmRadial'
  194. 'svmRadialCost'
  195. 'svmRadialWeights'
  196. 'svmSpectrumString'
  197. 'tan'
  198. 'tanSearch'
  199. 'treebag'
  200. 'vbmpRadial'
  201. 'widekernelpls'
  202. 'WM'
  203. 'wsrf'
  204. 'xgbLinear'
  205. 'xgbTree'
  206. 'xyf'

So first, let's create the formula:


In [19]:
formula <- as.formula(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width)

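Since every other column in the dataset is a predictor, the dot shorthand is equivalent here (shown with a hypothetical name so it doesn't overwrite the formula used below):

formula_dot <- as.formula(Species ~ .)  # same model: Species against all remaining columns
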
And train the model:


In [20]:
tr <- train(formula, train_sample, method="rpart") #, 
           # minsplit=2, minbucket=1, cp=0.001, maxdepth=8)


Loading required package: rpart

Additional settings/parameters of the model can be set in the train function. For example, minsplit in the rpart model is the minimum number of observations a node must have before a split is attempted. Similarly, minbucket is the minimum number of observations any leaf node must have, and cp is the complexity parameter (any split that does not decrease the overall lack of fit by a factor of cp is not attempted). Finally, maxdepth is the maximum depth of the tree (the root node being at depth 0).

N.B. The complexity parameter is used to control the size of the decision tree and to select the optimal tree size. This is useful if you want to look at the values of CP for various tree sizes. The default value is 0.01.
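
For completeness, here is a hedged sketch of how those settings could be passed through train (the variable name tr_deep is hypothetical, and cp is left out of the control list because caret tunes it itself):

tr_deep <- train(formula, train_sample, method = "rpart",
                 control = rpart::rpart.control(minsplit = 2, minbucket = 1, maxdepth = 8))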

Printing the tr object gives us some information about the model:


In [29]:
print(tr)


CART 

90 samples
 4 predictor
 3 classes: 'setosa', 'versicolor', 'virginica' 

No pre-processing
Resampling: Bootstrapped (25 reps) 
Summary of sample sizes: 90, 90, 90, 90, 90, 90, ... 
Resampling results across tuning parameters:

  cp         Accuracy   Kappa      Accuracy SD  Kappa SD  
  0.0000000  0.9290799  0.8917610  0.02847877   0.04353444
  0.3793103  0.8589198  0.7850328  0.11650106   0.17618093
  0.5517241  0.4458745  0.1901532  0.17222404   0.24096407

Accuracy was used to select the optimal model using  the largest value.
The final value used for the model was cp = 0. 

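caret also exposes varImp() if you just want the variable importances, without digging through the full summary() shown below (a quick aside, not part of the original run):

varImp(tr)  # importance of each predictor in the fitted CART model
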

In [27]:
summary(tr)


Call:
rpart(formula = .outcome ~ ., data = list(Sepal.Length = c(5.8, 
6, 5.5, 6.3, 6.7, 6.8, 5.7, 5.5, 5.4, 5.3, 6.9, 6.3, 5.7, 4.8, 
5, 6.1, 5.7, 6.3, 7.3, 5, 5.1, 4.9, 5.6, 5.9, 6, 5.5, 5.4, 5, 
5.8, 6.3, 6, 5.1, 6.9, 6.4, 5.1, 4.9, 7.7, 5.6, 5.4, 7.6, 6.6, 
6.8, 4.4, 6.5, 4.4, 5.4, 6.3, 5.5, 6.4, 4.6, 5.9, 6.7, 6.5, 5, 
6.3, 4.7, 6.5, 6.3, 4.7, 5, 5.8, 5.1, 5.1, 4.9, 5.2, 5.1, 7.9, 
6.6, 5.5, 6.8, 4.6, 5.7, 6.7, 6.1, 5.5, 5.7, 5.1, 4.6, 6.4, 6.4, 
5, 5.2, 7.2, 6.7, 4.4, 6, 6.2, 5.4, 6, 5.6), Sepal.Width = c(4, 
2.2, 2.5, 3.3, 3, 3.2, 2.8, 4.2, 3, 3.7, 3.1, 2.9, 2.8, 3.4, 
3.5, 2.9, 3.8, 2.3, 2.9, 3.3, 3.4, 3, 2.5, 3.2, 2.2, 3.5, 3.9, 
3.4, 2.7, 2.5, 3.4, 3.3, 3.1, 3.2, 3.8, 3.6, 3, 3, 3.4, 3, 3, 
2.8, 3, 3, 3.2, 3.9, 3.4, 2.3, 2.8, 3.2, 3, 3.3, 3.2, 3, 2.7, 
3.2, 3, 2.8, 3.2, 2.3, 2.7, 2.5, 3.5, 2.5, 2.7, 3.7, 3.8, 2.9, 
2.4, 3, 3.4, 2.5, 3.1, 2.8, 2.6, 3, 3.5, 3.1, 2.7, 3.1, 3.4, 
3.4, 3, 3.1, 2.9, 2.9, 3.4, 3.4, 2.7, 2.8), Petal.Length = c(1.2, 
5, 4, 4.7, 5.2, 5.9, 4.5, 1.4, 4.5, 1.5, 5.1, 5.6, 4.1, 1.9, 
1.3, 4.7, 1.7, 4.4, 6.3, 1.4, 1.5, 1.4, 3.9, 4.8, 4, 1.3, 1.7, 
1.6, 3.9, 4.9, 4.5, 1.7, 4.9, 4.5, 1.6, 1.4, 6.1, 4.1, 1.5, 6.6, 
4.4, 4.8, 1.3, 5.5, 1.3, 1.3, 5.6, 4, 5.6, 1.4, 4.2, 5.7, 5.1, 
1.6, 4.9, 1.6, 5.8, 5.1, 1.3, 3.3, 4.1, 3, 1.4, 4.5, 3.9, 1.5, 
6.4, 4.6, 3.7, 5.5, 1.4, 5, 5.6, 4.7, 4.4, 4.2, 1.4, 1.5, 5.3, 
5.5, 1.5, 1.4, 5.8, 4.4, 1.4, 4.5, 5.4, 1.7, 5.1, 4.9), Petal.Width = c(0.2, 
1.5, 1.3, 1.6, 2.3, 2.3, 1.3, 0.2, 1.5, 0.2, 2.3, 1.8, 1.3, 0.2, 
0.3, 1.4, 0.3, 1.3, 1.8, 0.2, 0.2, 0.2, 1.1, 1.8, 1, 0.2, 0.4, 
0.4, 1.2, 1.5, 1.6, 0.5, 1.5, 1.5, 0.2, 0.1, 2.3, 1.3, 0.4, 2.1, 
1.4, 1.4, 0.2, 1.8, 0.2, 0.4, 2.4, 1.3, 2.2, 0.2, 1.5, 2.1, 2, 
0.2, 1.8, 0.2, 2.2, 1.5, 0.2, 1, 1, 1.1, 0.2, 1.7, 1.4, 0.4, 
2, 1.3, 1, 2.1, 0.3, 2, 2.4, 1.2, 1.2, 1.2, 0.3, 0.2, 1.9, 1.8, 
0.2, 0.2, 1.6, 1.4, 0.2, 1.5, 2.3, 0.2, 1.6, 2), .outcome = c(1, 
3, 2, 2, 3, 3, 2, 1, 2, 1, 3, 3, 2, 1, 1, 2, 1, 2, 3, 1, 1, 1, 
2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 2, 2, 1, 1, 3, 2, 1, 3, 2, 2, 1, 
3, 1, 1, 3, 2, 3, 1, 2, 3, 3, 1, 3, 1, 3, 3, 1, 2, 2, 2, 1, 3, 
2, 1, 3, 2, 2, 3, 1, 3, 3, 2, 2, 2, 1, 1, 3, 3, 1, 1, 3, 2, 1, 
2, 3, 1, 2, 3)), control = list(minsplit = 20, minbucket = 7, 
    cp = 0, maxcompete = 4, maxsurrogate = 5, usesurrogate = 2, 
    surrogatestyle = 0, maxdepth = 30, xval = 0))
  n= 90 

         CP nsplit  rel error
1 0.5517241      0 1.00000000
2 0.3793103      1 0.44827586
3 0.0000000      2 0.06896552

Variable importance
Petal.Length  Petal.Width Sepal.Length  Sepal.Width 
          33           30           23           14 

Node number 1: 90 observations,    complexity param=0.5517241
  predicted class=setosa      expected loss=0.6444444  P(node) =1
    class counts:    32    32    26
   probabilities: 0.356 0.356 0.289 
  left son=2 (32 obs) right son=3 (58 obs)
  Primary splits:
      Petal.Length < 2.45 to the left,  improve=31.04368, (0 missing)
      Petal.Width  < 0.75 to the left,  improve=31.04368, (0 missing)
      Sepal.Length < 5.45 to the left,  improve=19.00606, (0 missing)
      Sepal.Width  < 3.15 to the right, improve=13.65276, (0 missing)
  Surrogate splits:
      Petal.Width  < 0.75 to the left,  agree=1.000, adj=1.000, (0 split)
      Sepal.Length < 5.45 to the left,  agree=0.900, adj=0.719, (0 split)
      Sepal.Width  < 3.15 to the right, agree=0.833, adj=0.531, (0 split)

Node number 2: 32 observations
  predicted class=setosa      expected loss=0  P(node) =0.3555556
    class counts:    32     0     0
   probabilities: 1.000 0.000 0.000 

Node number 3: 58 observations,    complexity param=0.3793103
  predicted class=versicolor  expected loss=0.4482759  P(node) =0.6444444
    class counts:     0    32    26
   probabilities: 0.000 0.552 0.448 
  left son=6 (30 obs) right son=7 (28 obs)
  Primary splits:
      Petal.Length < 4.85 to the left,  improve=21.399180, (0 missing)
      Petal.Width  < 1.65 to the left,  improve=21.302400, (0 missing)
      Sepal.Length < 6.15 to the left,  improve= 9.100647, (0 missing)
      Sepal.Width  < 2.95 to the left,  improve= 2.104231, (0 missing)
  Surrogate splits:
      Petal.Width  < 1.75 to the left,  agree=0.879, adj=0.750, (0 split)
      Sepal.Length < 6.15 to the left,  agree=0.810, adj=0.607, (0 split)
      Sepal.Width  < 2.95 to the left,  agree=0.638, adj=0.250, (0 split)

Node number 6: 30 observations
  predicted class=versicolor  expected loss=0.03333333  P(node) =0.3333333
    class counts:     0    29     1
   probabilities: 0.000 0.967 0.033 

Node number 7: 28 observations
  predicted class=virginica   expected loss=0.1071429  P(node) =0.3111111
    class counts:     0     3    25
   probabilities: 0.000 0.107 0.893 


In [26]:
post(tr$finalModel, file = '')

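The post() call above draws the fitted tree with rpart's base graphics; the rattle package we loaded at the start has fancyRpartPlot() for a more polished rendering of the same tree (a minimal sketch):

fancyRpartPlot(tr$finalModel)  # rattle's prettier rendering of the final CART tree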


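Finally, the held-out test sample has not been used yet. A quick check of the tree's out-of-sample performance (a sketch, assuming the objects above are still in the workspace; preds is a hypothetical name):

preds <- predict(tr, newdata = test_sample)   # predicted Species for the 60 held-out rows
confusionMatrix(preds, test_sample$Species)   # accuracy, kappa and the per-class table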