The Wise Old Tree

Gandalf is one of the primary fictional characters in J.R.R. Tolkien's The Lord of the Rings. He is a wise old wizard whom people turn to for counsel on matters that are tough to decide. A decision tree is a lot like Gandalf. The greatest gift of decision trees is that they make extremely hard decisions easy to interpret. They are what statisticians call a white box: we know exactly what's going on when a tree makes a decision. In this tutorial I will explain the mathematics behind how a decision tree works and how to build a functioning decision tree in R that can make predictions.

The packages

library(readr) # read in data
library(dplyr) # data manipulation
library(tidyr) # data manipulation
library(ggplot2) # for data viz
library(rpart) # decision tree
library(party) # decision tree

The data

I will be using the Pokemon data set from Kaggle. Pokemon is a media franchise owned by Nintendo. It became famous when the game was turned into an animated series in which an ambitious young boy tries to capture creatures (Pokemon) that help him win Pokemon battles. In his quest to become the best Pokemon trainer, he comes across a variety of Pokemon, but it was always his dream to capture the legendary Pokemon Mewtwo. A Pokemon is legendary when it is extremely powerful (measured by its attack and defense stats) and extremely rare.

The Pokemon data set contains all the Pokemon along with their types and stats. We will use a decision tree to decide whether a Pokemon is legendary or not. This will also help us see which factors contribute to that decision.

pokemon <- read_csv("C:/Users/routh/Desktop/Study Materials/My website/Trees/Pokemon.csv", col_types = cols(`#` = col_skip()))

Data Inspection

Here are the columns in the data with their descriptions:

  • Name: Name of each Pokemon
  • Type 1: Each Pokemon has a type, this determines weakness/resistance to attacks
  • Type 2: Some Pokemon are dual type and have a second type
  • Total: sum of all stats that come after this, a general guide to how strong a Pokemon is
  • HP: hit points, or health, defines how much damage a Pokemon can withstand before fainting
  • Attack: the base modifier for normal attacks (e.g. Scratch, Punch)
  • Defense: the base damage resistance against normal attacks
  • SP Atk: special attack, the base modifier for special attacks (e.g. fire blast, bubble beam)
  • SP Def: the base damage resistance against special attacks
  • Speed: determines which Pokemon attacks first each round
  • Generation: the generation the Pokemon belongs to
  • Legendary: whether the Pokemon is legendary (True/False); this is our response variable
summary(pokemon)
##      Name              Type 1             Type 2              Total      
##  Length:800         Length:800         Length:800         Min.   :180.0  
##  Class :character   Class :character   Class :character   1st Qu.:330.0  
##  Mode  :character   Mode  :character   Mode  :character   Median :450.0  
##                                                           Mean   :435.1  
##                                                           3rd Qu.:515.0  
##                                                           Max.   :780.0  
##        HP             Attack       Defense          Sp. Atk      
##  Min.   :  1.00   Min.   :  5   Min.   :  5.00   Min.   : 10.00  
##  1st Qu.: 50.00   1st Qu.: 55   1st Qu.: 50.00   1st Qu.: 49.75  
##  Median : 65.00   Median : 75   Median : 70.00   Median : 65.00  
##  Mean   : 69.26   Mean   : 79   Mean   : 73.84   Mean   : 72.82  
##  3rd Qu.: 80.00   3rd Qu.:100   3rd Qu.: 90.00   3rd Qu.: 95.00  
##  Max.   :255.00   Max.   :190   Max.   :230.00   Max.   :194.00  
##     Sp. Def          Speed          Generation     Legendary        
##  Min.   : 20.0   Min.   :  5.00   Min.   :1.000   Length:800        
##  1st Qu.: 50.0   1st Qu.: 45.00   1st Qu.:2.000   Class :character  
##  Median : 70.0   Median : 65.00   Median :3.000   Mode  :character  
##  Mean   : 71.9   Mean   : 68.28   Mean   :3.324                     
##  3rd Qu.: 90.0   3rd Qu.: 90.00   3rd Qu.:5.000                     
##  Max.   :230.0   Max.   :180.00   Max.   :6.000
glimpse(pokemon)
## Observations: 800
## Variables: 12
## $ Name       <chr> "Bulbasaur", "Ivysaur", "Venusaur", "VenusaurMega V...
## $ `Type 1`   <chr> "Grass", "Grass", "Grass", "Grass", "Fire", "Fire",...
## $ `Type 2`   <chr> "Poison", "Poison", "Poison", "Poison", NA, NA, "Fl...
## $ Total      <int> 318, 405, 525, 625, 309, 405, 534, 634, 634, 314, 4...
## $ HP         <int> 45, 60, 80, 80, 39, 58, 78, 78, 78, 44, 59, 79, 79,...
## $ Attack     <int> 49, 62, 82, 100, 52, 64, 84, 130, 104, 48, 63, 83, ...
## $ Defense    <int> 49, 63, 83, 123, 43, 58, 78, 111, 78, 65, 80, 100, ...
## $ `Sp. Atk`  <int> 65, 80, 100, 122, 60, 80, 109, 130, 159, 50, 65, 85...
## $ `Sp. Def`  <int> 65, 80, 100, 120, 50, 65, 85, 85, 115, 64, 80, 105,...
## $ Speed      <int> 45, 60, 80, 80, 65, 80, 100, 100, 100, 43, 58, 78, ...
## $ Generation <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
## $ Legendary  <chr> "False", "False", "False", "False", "False", "False...

Data transformations

For a classification task, we need to convert the response Legendary to a factor, along with Generation. I will also create a variable Dual to indicate whether a Pokemon has a dual type. After that we can drop the irrelevant columns.

pokemon <- pokemon %>%
              mutate(Legendary = as.factor(Legendary),
                     Generation = as.factor(Generation),
                     Dual = factor(ifelse(is.na(`Type 1`) | is.na(`Type 2`),0,1))) %>%
              rename(sp.attack = `Sp. Atk`, sp.defence = `Sp. Def`) %>%
              select(4:13)

The Mechanics

Before moving on to how to build a decision tree in R and predict with it, we need to understand how a decision tree operates, the important terminology, how to make it work better and, most importantly, how it decides what to choose.

Terminology:

Decision Trees or CART (Classification and Regression Trees) have the following elements:

Take a look at this picture.

  • A root node is where the first split occurs. A decision or internal node is any of the nodes following the root node where a decision is made. Finally, a terminal or leaf node is where the outcomes (based on the previous decisions) are finalized.

  • The splitting node is the parent node to its child nodes. The children form the branches of the tree.

Splitting Criterion

We know that splitting occurs at every node. But what splits? Or, more importantly, how does it split? The split occurs on the values of the predictor variables. A node splits into two branches (which is why it's also called a binary tree). Technically, a node can split into more than two branches, but most trees use binary splits. That brings us to the first kind of tree, the Recursive Partitioning (RP) tree. It does exactly what the name suggests. Refer to this nice picture.

If you think about a decision tree with 2 predictor variables, what an RP tree does is partition the predictor space again and again until each little partition contains (mostly) one class of the response. That's it! That's all that a decision tree does. Brilliant yet simple! But there is one final hurdle: how does it decide which variable to pick first, and at what value of a predictor variable should it ideally split?
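Before answering those questions, here is a small illustrative sketch of what these partitions look like when a tree is grown on just two of our Pokemon predictors (toy.tree and the prediction grid are made up purely for this picture):

# fit a small tree on two predictors and colour a grid of points by its predicted class;
# the axis-aligned rectangles that appear are exactly the recursive partitions
toy.tree <- rpart(Legendary ~ Attack + sp.attack, data = pokemon, method = "class")

grid <- expand.grid(Attack = seq(min(pokemon$Attack), max(pokemon$Attack), length.out = 100),
                    sp.attack = seq(min(pokemon$sp.attack), max(pokemon$sp.attack), length.out = 100))
grid$pred <- predict(toy.tree, newdata = grid, type = "class")

ggplot(grid, aes(x = Attack, y = sp.attack, fill = pred)) +
  geom_tile(alpha = 0.3) +
  geom_point(data = pokemon, aes(x = Attack, y = sp.attack, col = Legendary), inherit.aes = FALSE) +
  theme_minimal()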

To answer the first question: in an RP tree, there are a number of criteria that can be used to split a node. Information Gain is a popular splitting criterion that uses entropy. Entropy is a fundamental concept of Information Theory, which is, to be honest, a huge topic by itself and (much) beyond the scope of this tutorial. Intuitively, entropy measures how much information is missing. If you asked me where I am from and I answered "India", there is a lot of "entropy" in this answer. Translated to decision trees and splits, the entropy of a variable is how much information you would lose if you picked that variable to split on out of many other variables. Mathematically, this is:

\[ E = \sum_{i=1}^{c} -p_i \log_2(p_i) \]

Let's understand with an example. Let's say the response variable is whether I go out today or not. Arbitrarily, out of 14 instances, I assign 9 to yes (I go out) and 5 to no. So what is the entropy for these decisions?

entropy <- function(vector){
  p = vector/sum(vector) # gives you the p_i's
  sum(-(p*log2(p))) # the sum
}

entropy(c(9,5))
## [1] 0.940286

Now let's ramp it up a little. I now decide to go out depending on (a) whether it is sunny or not, or (b) whether I'm feeling lazy or not. Here are two 2x2 tables to help you visualize:

set.seed(4)
fake.data.sunny <- data.frame(sunny = sample(c(1,0),14,replace = T),
                              go.out = sample(c(1,0),14,replace = T))

with(fake.data.sunny,table(sunny,go.out))
##      go.out
## sunny 0 1
##     0 6 1
##     1 4 3
fake.data.lazy <- data.frame(lazy = sample(c(1,0),14,replace = T),
                             go.out = sample(c(1,0),14,replace = T))

with(fake.data.lazy,table(lazy,go.out))
##     go.out
## lazy 0 1
##    0 7 1
##    1 3 3

The entropy for two attributes (go out combined with sunny, or go out combined with lazy) is given by:

\[ Entropy(Target|Variable) = \sum_{v \in Variable} P(v) \times Entropy(Target|Variable = v) \]

This helper function computes the conditional entropy for a variable:

entropy.var <- function(...){  # enter a list of vectors
  list.vector = list(...)
  p = purrr::map_dbl(list.vector,~sum(.x))
  prop = p/sum(p)
  sum(prop*(purrr::map_dbl(list.vector,entropy)),na.rm = T)
}

# for sunny
entropy.var(c(6,1),c(4,3))  # 6+1 & 4+3
## [1] 0.7884505
# for lazy
entropy.var(c(7,1),c(3,3)) 
## [1] 0.7391797

Therefore the Information Gain is:

\[ Gain(Target|Variable) = Entropy(Target) - Entropy(Target|Variable) \]

For sunny it would be \(0.94 - 0.788 = 0.15\) and for lazy it would be \(0.94 - 0.739 = 0.20\). The decision tree selects the variable with the largest information gain. In this case, whether I go out or not would depend first on whether I was feeling lazy (so true) and then on whether it was sunny.
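To make that arithmetic explicit, here is a small sketch that reuses the entropy() and entropy.var() functions defined above with the same counts:

# information gain = entropy of the target minus conditional entropy given the variable
gain.sunny <- entropy(c(9,5)) - entropy.var(c(6,1),c(4,3))
gain.lazy  <- entropy(c(9,5)) - entropy.var(c(7,1),c(3,3))

round(c(sunny = gain.sunny, lazy = gain.lazy), 2)  # ~0.15 and ~0.20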

To answer the second question of which value to split on, we can use Information Gain here too. For every candidate value of a variable, the Information Gain is calculated, and the maximum IG is taken as the best possible IG for that variable. This best possible IG is computed for every variable. We then choose the variable with the highest maximum IG, and its best split point (the one with the highest IG) becomes the point of split.
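As a rough sketch of that search for a single numeric predictor, the hypothetical helper below scans every candidate threshold, computes the information gain of the resulting binary split with the functions defined earlier, and keeps the best one:

# illustrative helper (not part of rpart): scan candidate thresholds of a numeric
# predictor x for a factor target y and keep the split with the highest information gain
best_split <- function(x, y){
  parent <- entropy(as.numeric(table(y)))           # entropy before splitting
  cuts <- sort(unique(x))
  gains <- sapply(cuts, function(cut){
    left  <- as.numeric(table(y[x <= cut]))         # class counts on each side
    right <- as.numeric(table(y[x >  cut]))
    parent - entropy.var(left, right)
  })
  c(split = cuts[which.max(gains)], gain = max(gains))
}

best_split(pokemon$Total, pokemon$Legendary)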

Now, CARTs can also be used to build regression trees, where the outcome is not categorical. The splitting criterion for such trees is instead the MSE for a particular predictor (p) with respect to the response, where the MSE is:

\[ MSE_p = \frac{1}{n}\sum_{i=1}^{n} (\hat{Y}_{i,p} - Y_{i,p})^2 \]

where \(\hat{Y}_{i,p}\) is the mean response value for the region or partition that observation i falls into. Here, the variable with the lowest MSE is chosen as the splitting variable, and the value of the predictor with the minimum possible MSE is chosen as the split point.

Note: The second kind of tree is called the Conditional Partitioning (or conditional inference) tree. Instead of using IG as the splitting criterion, this tree is grown using hypothesis tests at every node, and the variable that is most significant is chosen as the candidate to split on.

Pruning

A decision tree grown using the process above may produce good predictions on the training set, but it is likely to overfit the data, leading to poor test set performance. A way to reduce this variance (overfitting) is to grow a smaller tree, at the cost of a little more bias.

One possible way to tackle this is to build the tree only as long as the decrease in the RSS or increase in IG due to each split exceeds some (high) threshold. This strategy results in smaller trees, but it is too short-sighted, since a seemingly worthless split early on in the tree might be followed by a very good split later on.

Another way is to grow the tree to full depth and then prune it back to a subtree that gives the lowest possible test error rate. Later on we will see how to calculate the CV error for a tree. If we record the CV error for all possible subtrees, we can select the subtree with the lowest error. However, this can be an inefficient procedure, especially for a larger tree with a lot of internal nodes. With the prune function in R we essentially tell rpart to keep only those splits that improve the fit by a given amount (specified by the CP, or cost-complexity, parameter) and to cut away the rest.

Prediction

Finally we are ready to implement this procedure in R. We will use the Pokemon data to predict if a Pokemon with specific features is Legendary or not.

Split into train-test

train_test_split <- function(data,percent,seed){
  set.seed(seed)
  rows = sample(nrow(data))
  data <- data[rows,]
  split = round(nrow(data)*percent)
  list(train = data[1:split,], test =  data[(split+1):nrow(data),])
}


list <- train_test_split(pokemon,0.7,123)
train <- list$train; test <- list$test

Recursive partitioning

Recursive partitioning is implemented using the rpart function in the rpart package. You can specify the categorical nature of the response with the method argument. The fancyRpartPlot function in the rattle package helps visualize the tree.

rtree.fit <- rpart(formula = Legendary ~.,
                   method = "class",
                   data = train) 

rattle::fancyRpartPlot(rtree.fit)

This is R's representation of the decision tree. We can see that not all variables were used to grow the tree; the most important variable was Total. Let's take the first node: 92% of the observations are not Legendary, and the node splits on whether Total is less than 580. The label (True/False) above these proportions indicates how the node votes, while the numbers below indicate the composition of the node.

Let's move on to the next node. If Total is less than 580, you move left and reach the second node. Here 86% (of the training data) of the Pokemon are not Legendary. The tree stops here because this node votes all of these Pokemon as non-legendary (indicated by the 1 against 0). If we move to node 3 instead, where Total is greater than 580, out of the remaining 14% of the Pokemon (note that 86 + 14 = 100% of the data, which shows the binary partition) about 42% are correctly voted and 58% are incorrectly voted as non-legendary. Therefore, this node splits again. This process continues until all nodes (see the nodes below) are as pure as possible. Theoretically one could keep going until every node is 100% pure.
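To back up the claim that Total drives the tree, we can look at the variable importance scores that rpart stores on the fitted object (a quick sketch; the exact numbers will depend on your train/test split):

# variable importance recorded by rpart (larger = more influential in the splits)
importance <- rtree.fit$variable.importance

data.frame(variable = names(importance), importance = importance) %>%
  ggplot(aes(x = reorder(variable, importance), y = importance)) +
  geom_bar(stat = "identity") +
  coord_flip() +
  theme_minimal()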

Controlled split

You can further control the splitting process using the rpart.control option and parms option. Here is a demonstration:

control <- rpart.control(minsplit = 5, # min #obs to attempt splitting
                         minbucket = 15, 
                         cp = 0.07,
                         maxcompete = 4,
                         xval = 5, # number of cross validations
                         maxdepth = 15) # how deep should the tree grow on any node

control.tree <-  rpart(formula = Legendary ~.,
                   method = "class",
                   data = train,
                   parms = list(prior = c(0.2,0.8), split = "information"),  # default is gini
                   control = control) 

rattle::fancyRpartPlot(control.tree)

Conditional partitioning

Conditional partitioning is implemented using the ctree function from the party package.

ctree.fit <- ctree(formula = Legendary ~.,
                   data = train)
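party also provides a plot method for ctree objects, so the fitted conditional inference tree can be visualized directly (the "simple" type gives a more compact display):

plot(ctree.fit, type = "simple")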

Legendary or not?

Predictions on the test set are done using the generic predict function.

  • Recursive Tree Fit
recursive.prediction <- predict(rtree.fit, newdata = test, type = "class")
  • Conditional Tree Fit
conditional.prediction <- predict(ctree.fit, newdata = test)

Prediction accuracy

It's pretty straightforward to compute the prediction accuracy. In this case, the prediction accuracy is very high (0.94).

predict.df <- data.frame(
  predictions = recursive.prediction,
  actual = test$Legendary
)

(t <- table(predict.df))
##            actual
## predictions False True
##       False   214    8
##       True      6   12
print(paste('Accuracy:',round((t[1]+t[4])/(t[1]+t[2]+t[3]+t[4]),2)))
## [1] "Accuracy: 0.94"

Using ROC plot

If the outcome is probabilistic, we can build a ROC curve to measure accuracy. This version makes probabilistic predictions, and the ROC curve is plotted to measure the area under the curve. The prediction and performance functions from the ROCR package are used to compute the TPR and FPR. In this case the strong ROC curve shows how good the classification is.

rtree.fit.prob <- rpart(formula = Legendary ~.,
                   method = "anova",
                   data = train)
recursive.prediction.prob <- predict(rtree.fit.prob, newdata = test)


rtree_roc <- rtree.fit.prob %>%
      predict(newdata = test) %>%
      ROCR::prediction(test$Legendary) %>%
      ROCR::performance("tpr", "fpr")

roc_df <- data.frame(
  FPR = rtree_roc@x.values[[1]],
  TPR = rtree_roc@y.values[[1]],
  cutoff = rtree_roc@alpha.values[[1]]
)

ggplot(roc_df,aes(x = FPR, y = TPR))+
  geom_point()+
  geom_line()+
  theme_minimal()
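Since ROCR already has everything it needs, the area under this curve can also be extracted with the "auc" performance measure; this sketch mirrors the pipeline above:

rtree_auc <- rtree.fit.prob %>%
      predict(newdata = test) %>%
      ROCR::prediction(test$Legendary) %>%
      ROCR::performance("auc")

rtree_auc@y.values[[1]]   # area under the ROC curve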

Pruning

Finally one can prune a tree using the prune function.

my3cols <- c("#E7B800", "#2E9FDF", "#FC4E07")
my2cols <- c("#2E9FDF", "#FC4E07")

cp <- data.frame(rtree.fit$cptable)
best.cp <- cp[which.min(cp$xerror),"CP"]

cp %>%
  select(CP,nsplit,xerror,rel.error)%>%
  gather(key,value,-nsplit)%>%
  ggplot(aes(x = nsplit, y = value, col = key))+
  geom_point(size = 2)+
  geom_line(size = 1.1)+
  scale_color_manual(values = my3cols)+
  theme_minimal()

The results show that 3 splits give the minimum xerror (related to the PRESS error) at a CP value of approximately 0.06. We can prune the tree using this CP value. If we chose to use the relative error instead, we would choose 6 splits and a CP of 0.01. Different error measures give different CP values and different numbers of splits. We will create the pruned tree by supplying the CP value that minimizes xerror and then get the prediction accuracy.

prune.tree <- prune(rtree.fit, cp = best.cp)

rattle::fancyRpartPlot(prune.tree)

predictions.prune <- predict(prune.tree, newdata = test, type = "class")

predict.df.prune <- data.frame(
  predictions = predictions.prune,
  actual = test$Legendary
)

(t <- table(predict.df.prune))
##            actual
## predictions False True
##       False   214    8
##       True      6   12
print(paste('Accuracy:',round((t[1]+t[4])/(t[1]+t[2]+t[3]+t[4]),2)))
## [1] "Accuracy: 0.94"

In this case, pruning doesn't really improve the prediction accuracy.

Using Cross-Validation

Resampling methods such as cross-validation can also be used on the training data to pick a CP value with which to prune and fit the best tree. The most convenient way to perform cross-validation is to use the train function from the caret package. Printing the object outputs the best CP value (picked here according to accuracy, i.e. 1 - misclassification rate), which can then be used to prune the tree.

tc <- caret::trainControl("cv",50,classProbs = T)

train.rpart <- caret::train(Legendary ~., 
                             data = train, 
                             method="rpart",
                             trControl=tc, 
                             tuneLength = 10, 
                             parms=list(split='information'))

train.rpart
## CART 
## 
## 560 samples
##   9 predictor
##   2 classes: 'False', 'True' 
## 
## No pre-processing
## Resampling: Cross-Validated (50 fold) 
## Summary of sample sizes: 549, 549, 549, 550, 549, 549, ... 
## Resampling results across tuning parameters:
## 
##   cp          Accuracy   Kappa    
##   0.00000000  0.9359091  0.5002696
##   0.02962963  0.9322424  0.4419367
##   0.05925926  0.9392121  0.5448884
##   0.08888889  0.9410303  0.5786528
##   0.11851852  0.9448182  0.6275213
##   0.14814815  0.9448182  0.6404524
##   0.17777778  0.9464848  0.6612857
##   0.20740741  0.9358788  0.7028026
##   0.23703704  0.9393636  0.7327317
##   0.26666667  0.8984545  0.2406340
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was cp = 0.1777778.
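The CP value picked by caret lives in train.rpart$bestTune. As a final sketch (assuming we apply it to the rpart fit from earlier), we can prune with it and check the test accuracy:

# prune the earlier recursive partitioning fit with the cross-validated cp
cv.cp <- train.rpart$bestTune$cp
cv.prune <- prune(rtree.fit, cp = cv.cp)

cv.predictions <- predict(cv.prune, newdata = test, type = "class")
mean(cv.predictions == test$Legendary)   # test-set accuracy of the CV-pruned tree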
