Gandalf is one of the primary fictional characters in J.R.R. Tolkien’s The Lord of the Rings. He is a wise old wizard whom people turn to for counsel on matters that are tough to decide. A decision tree is a lot like Gandalf. The greatest gift of a decision tree is that it makes extremely hard decisions easy to interpret. It’s what statisticians call a white box: we know exactly what’s going on when a tree makes a decision. In this tutorial I will explain the mathematics behind a decision tree and show how to build a working one in R that can make predictions.
The packages
library(readr) # read in data
library(dplyr) # data manipulation
library(tidyr) # data manipulation
library(ggplot2) # for data viz
library(rpart) # decision tree
library(party) # decision tree
The data
I will be using the Pokemon data set from Kaggle. Pokemon is a media franchise owned by Nintendo. It became famous when the game was turned into an animated series in which an ambitious young boy tries to capture creatures (Pokemon) who help him win Pokemon battles. In his quest to become the best Pokemon trainer, he comes across a variety of Pokemon, but it was always his dream to capture the legendary Pokemon Mewtwo. A Pokemon is legendary when it’s extremely powerful (as measured by its attack and defense stats) and extremely rare.
The Pokemon data set comes with all the Pokemon and their type and stats. We will use a decision tree to decide if the Pokemon is legendary or not. This would also help us realize what factors contribute to deciding if a Pokemon is legendary or not.
pokemon <- read_csv("C:/Users/routh/Desktop/Study Materials/My website/Trees/Pokemon.csv", col_types = cols(`#` = col_skip()))
Data Inspection
Here are the columns in the data with their descriptions:
- Name: Name of each Pokemon
- Type 1: Each Pokemon has a type, this determines weakness/resistance to attacks
- Type 2: Some Pokemon are dual type and have 2
- Total: sum of all stats that come after this, a general guide to how strong a Pokemon is
- HP: hit points, or health, defines how much damage a Pokemon can withstand before fainting
- Attack: the base modifier for normal attacks (e.g. Scratch, Punch)
- Defense: the base damage resistance against normal attacks
- Sp. Atk: special attack, the base modifier for special attacks (e.g. Fire Blast, Bubble Beam)
- Sp. Def: the base damage resistance against special attacks
- Speed: determines which Pokemon attacks first each round
- Generation: the generation the Pokemon belongs to
- Legendary: whether the Pokemon is legendary (True/False); this is the response we want to predict
summary(pokemon)
## Name Type 1 Type 2 Total
## Length:800 Length:800 Length:800 Min. :180.0
## Class :character Class :character Class :character 1st Qu.:330.0
## Mode :character Mode :character Mode :character Median :450.0
## Mean :435.1
## 3rd Qu.:515.0
## Max. :780.0
## HP Attack Defense Sp. Atk
## Min. : 1.00 Min. : 5 Min. : 5.00 Min. : 10.00
## 1st Qu.: 50.00 1st Qu.: 55 1st Qu.: 50.00 1st Qu.: 49.75
## Median : 65.00 Median : 75 Median : 70.00 Median : 65.00
## Mean : 69.26 Mean : 79 Mean : 73.84 Mean : 72.82
## 3rd Qu.: 80.00 3rd Qu.:100 3rd Qu.: 90.00 3rd Qu.: 95.00
## Max. :255.00 Max. :190 Max. :230.00 Max. :194.00
## Sp. Def Speed Generation Legendary
## Min. : 20.0 Min. : 5.00 Min. :1.000 Length:800
## 1st Qu.: 50.0 1st Qu.: 45.00 1st Qu.:2.000 Class :character
## Median : 70.0 Median : 65.00 Median :3.000 Mode :character
## Mean : 71.9 Mean : 68.28 Mean :3.324
## 3rd Qu.: 90.0 3rd Qu.: 90.00 3rd Qu.:5.000
## Max. :230.0 Max. :180.00 Max. :6.000
glimpse(pokemon)
## Observations: 800
## Variables: 12
## $ Name <chr> "Bulbasaur", "Ivysaur", "Venusaur", "VenusaurMega V...
## $ `Type 1` <chr> "Grass", "Grass", "Grass", "Grass", "Fire", "Fire",...
## $ `Type 2` <chr> "Poison", "Poison", "Poison", "Poison", NA, NA, "Fl...
## $ Total <int> 318, 405, 525, 625, 309, 405, 534, 634, 634, 314, 4...
## $ HP <int> 45, 60, 80, 80, 39, 58, 78, 78, 78, 44, 59, 79, 79,...
## $ Attack <int> 49, 62, 82, 100, 52, 64, 84, 130, 104, 48, 63, 83, ...
## $ Defense <int> 49, 63, 83, 123, 43, 58, 78, 111, 78, 65, 80, 100, ...
## $ `Sp. Atk` <int> 65, 80, 100, 122, 60, 80, 109, 130, 159, 50, 65, 85...
## $ `Sp. Def` <int> 65, 80, 100, 120, 50, 65, 85, 85, 115, 64, 80, 105,...
## $ Speed <int> 45, 60, 80, 80, 65, 80, 100, 100, 100, 43, 58, 78, ...
## $ Generation <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
## $ Legendary <chr> "False", "False", "False", "False", "False", "False...
Data transformations
In a classification task, we need to convert the response Legendary to a factor, along with Generation. I will also create a variable Dual to indicate whether a Pokemon has two types. After that we can get rid of the irrelevant columns.
pokemon <- pokemon %>%
  mutate(Legendary = as.factor(Legendary),
         Generation = as.factor(Generation),
         # Dual = 1 when both types are present, 0 otherwise
         Dual = factor(ifelse(is.na(`Type 1`) | is.na(`Type 2`), 0, 1))) %>%
  rename(sp.attack = `Sp. Atk`, sp.defence = `Sp. Def`) %>%
  select(4:13) # keep Total through Dual; drops Name, Type 1 and Type 2
The Mechanics
Before moving on to how to build a decision tree and predict with it in R, we need to understand how a decision tree operates, the important terminology, how to make it work better and, most importantly, how it decides what to choose.
Terminology:
Decision Trees or CART (Classification and Regression Trees) have the following elements:
Take a look at this picture.
A root node is where the first split occurs. A decision (or internal) node is any node after the root where a further decision is made. Finally, a terminal (or leaf) node is where the outcomes, based on the previous decisions, are finalized.
A node that splits is the parent node of its child nodes. The children form the branches of the whole tree.
Splitting Criterion
We know that splitting occurs at every node. But what splits? Or, more importantly, how does it split? The split occurs on the values of the predictor variables. A node splits into two branches (which is why it’s also called a binary tree). Technically a node can split into more than two branches, but most trees use binary splits. That brings us to the first kind of tree, the Recursive Partitioning (RP) tree. It does exactly what the name suggests. Refer to this nice picture
Think about a decision tree with 2 variables: what an RP tree does is partition the space (containing values of the response) again and again until each little partition holds only one class of the response. That’s it! That’s all that a decision tree does. Brilliant yet simple! But there is one final hurdle: how does it decide which variable to pick first, and at what value of a predictor variable it should ideally split?
To answer the first question: in an RP tree, a number of criteria can be used to split a node. Information Gain is a popular splitting criterion that uses entropy. Entropy is a fundamental concept of Information Theory. Tbh, Information Theory is a huge topic by itself and (much) beyond the scope of this tutorial (and myself). Intuitively, entropy is how much information is missing. If you asked me where I am from and I answered “India”, there is lots of “entropy” in this answer, because it still leaves a lot unknown. Translated to decision trees and splits, the entropy of a node measures how mixed the classes of the response are, and the entropy after a split tells you how much uncertainty remains once you know the value of the splitting variable. With \(c\) classes and \(p_i\) the proportion of observations in class \(i\), entropy is:
\[ E = -\sum_{i=1}^c p_i \log_2(p_i) \]
Let’s understand this with an example. Say the response variable is whether I go out today or not. Arbitrarily, out of 14 instances, I assign 9 to yes (I go out) and 5 to no. So what is the entropy of these decisions?
entropy <- function(vector){
  p = vector/sum(vector) # class proportions, the p_i's
  sum(-(p*log2(p)))      # minus the sum of p_i * log2(p_i)
}
entropy(c(9,5))
## [1] 0.940286
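To get a feel for the scale of entropy, it helps to look at the two extremes. The sketch below uses a hypothetical helper, entropy_safe (my name, not part of the original code), which drops empty classes first so that a pure node returns 0 instead of NaN; it assumes the usual convention that 0 * log2(0) counts as 0.

```r
# Entropy (in bits) of a vector of class counts.
# Zero counts are dropped first, using the convention 0 * log2(0) = 0.
entropy_safe <- function(counts) {
  p <- counts[counts > 0] / sum(counts)
  -sum(p * log2(p))
}

entropy_safe(c(9, 5))   # the 9 yes / 5 no example
entropy_safe(c(7, 7))   # evenly mixed node: maximum uncertainty for two classes
entropy_safe(c(14, 0))  # pure node: no uncertainty
```

For two classes the value runs from 0 (pure) to 1 (a 50/50 mix), so 0.94 means the 9/5 node is close to maximally mixed.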
Now let’s ramp it up a little. Say I decide to go out depending on (a) whether it is sunny or not, or (b) whether I’m feeling lazy or not. Here are the two 2x2 tables to help you visualize:
set.seed(4)
fake.data.sunny <- data.frame(sunny = sample(c(1,0),14,replace = T),
                              go.out = sample(c(1,0),14,replace = T))
with(fake.data.sunny,table(sunny,go.out))
## go.out
## sunny 0 1
## 0 6 1
## 1 4 3
fake.data.lazy <- data.frame(lazy = sample(c(1,0),14,replace = T),
                             go.out = sample(c(1,0),14,replace = T))
with(fake.data.lazy,table(lazy,go.out))
## go.out
## lazy 0 1
## 0 7 1
## 1 3 3
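The entropy of the target after splitting on a variable (go out given sunny, or go out given lazy) is the weighted average of the entropies of the groups the split creates:
\[ Entropy(Target|Variable) = \sum_{v} P(Variable = v) \cdot Entropy(Target|Variable = v) \]
This helper function computes that weighted entropy; each argument is the vector of class counts for one value of the variable: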
entropy.var <- function(...){ # one vector of class counts per value of the variable
  list.vector = list(...)
  p = purrr::map_dbl(list.vector,~sum(.x)) # observations per value
  prop = p/sum(p)                          # P(Variable = v)
  sum(prop*(purrr::map_dbl(list.vector,entropy)),na.rm = T) # weighted entropy
}
# for sunny
entropy.var(c(6,1),c(4,3)) # 6+1 & 4+3
## [1] 0.7884505
# for lazy
entropy.var(c(7,1),c(3,3))
## [1] 0.7391797
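Therefore Information Gain is:
\[ Gain(Target|Variable) = Entropy(Target) - Entropy(Target|Variable) \]
Using the 0.94 baseline from the 9/5 example, the gain for sunny would be \(0.94 - 0.788 \approx 0.15\) and for lazy \(0.94 - 0.739 \approx 0.2\). Strictly speaking, the simulated tables each contain 10 “no” and 4 “yes” outcomes, so their own baseline is entropy(c(10,4)) \(\approx 0.863\), which gives gains of about 0.075 and 0.124; the ranking is the same either way. The decision tree selects the variable with the largest information gain. In this case, my going out or not would depend first on whether I was feeling lazy (so true) and then on whether it was sunny.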
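The gain calculation can also be done in one step straight from a contingency table. This is a minimal sketch, not the method rpart uses internally; node_entropy and info_gain are hypothetical helper names, and the two tables are re-entered by hand from the output above. Because the baseline entropy is taken from each table’s own class totals (10 and 4), the gains are smaller than figures based on 0.94, but the ranking of the two variables is the same.

```r
# Entropy (bits) of class counts, dropping empty classes.
node_entropy <- function(counts) {
  p <- counts[counts > 0] / sum(counts)
  -sum(p * log2(p))
}

# Information gain of a predictor, given a table with one row per
# predictor value and one column per class of the response.
info_gain <- function(tab) {
  parent   <- node_entropy(colSums(tab))   # entropy before the split
  weights  <- rowSums(tab) / sum(tab)      # share of observations in each branch
  children <- sum(weights * apply(tab, 1, node_entropy))
  parent - children
}

# Tables copied from the output above (rows = predictor, columns = go.out)
sunny <- matrix(c(6, 4, 1, 3), nrow = 2,
                dimnames = list(sunny = c("0", "1"), go.out = c("0", "1")))
lazy  <- matrix(c(7, 3, 1, 3), nrow = 2,
                dimnames = list(lazy = c("0", "1"), go.out = c("0", "1")))

gains <- c(sunny = info_gain(sunny), lazy = info_gain(lazy))
round(gains, 3)
names(which.max(gains))  # the variable the tree would split on first
```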
To answer the second question, which value to split on, we can use Information Gain too. For every candidate split value of a variable, the Information Gain is calculated, and the maximum is taken as the best possible IG for that variable. This best possible IG is calculated for all variables. We then choose the variable with the highest maximum IG, and its best split point (the one with the highest IG) becomes the point of split.
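For a numeric predictor, the search over split points can be sketched as a scan over the midpoints between consecutive sorted values. The helper names (node_entropy, best_split) and the toy data are made up for illustration; real implementations such as rpart are more efficient and default to a different impurity measure (Gini).

```r
# Entropy (bits) of class counts, dropping empty classes.
node_entropy <- function(counts) {
  p <- counts[counts > 0] / sum(counts)
  -sum(p * log2(p))
}

# Scan candidate thresholds (midpoints between sorted unique values of x)
# and return the one with the largest information gain for class labels y.
best_split <- function(x, y) {
  values <- sort(unique(x))
  cuts   <- (head(values, -1) + tail(values, -1)) / 2
  parent <- node_entropy(table(y))
  gains  <- sapply(cuts, function(cut) {
    left  <- table(y[x <  cut])
    right <- table(y[x >= cut])
    w     <- mean(x < cut)  # share of observations sent to the left branch
    parent - (w * node_entropy(left) + (1 - w) * node_entropy(right))
  })
  list(cut = cuts[which.max(gains)], gain = max(gains))
}

# Toy data (made up): high totals are legendary
total     <- c(300, 320, 410, 450, 600, 680, 700)
legendary <- c("False", "False", "False", "False", "True", "True", "True")
res <- best_split(total, legendary)
res$cut   # 525, the midpoint between 450 and 600
```

Running the same scan for every predictor and keeping the overall winner gives both the splitting variable and its split point.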
Now, CARTs can also be used to build regression trees, where the outcome is numeric rather than categorical. The splitting criterion for such trees is the MSE of the response for a split on a particular predictor (p), where MSE is:
\[ MSE_p = \frac{1}{n}\sum_{i=1}^n (Y_{i,p} - \hat{Y}_{i,p})^2 \]
Where \(\hat{Y}_{i,p}\) is the mean response value of the region or partition that observation \(i\) falls into. Here, the variable with the lowest MSE is chosen as the splitting variable, and the value of the predictor with the minimum possible MSE is chosen as the split point.
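As a small worked example with made-up numbers, take four responses 1, 2, 9, 10 (ordered by some predictor) and compare the mean squared deviation from the partition means before and after splitting them into two groups:

```latex
\[
\begin{aligned}
\text{No split:}\quad & \bar{Y} = 5.5, \quad MSE = \tfrac{1}{4}\left(4.5^2 + 3.5^2 + 3.5^2 + 4.5^2\right) = 16.25 \\
\text{Split into } \{1, 2\} \text{ and } \{9, 10\}:\quad & \hat{Y}_{left} = 1.5,\ \hat{Y}_{right} = 9.5, \quad MSE = \tfrac{1}{4}\left(4 \times 0.5^2\right) = 0.25
\end{aligned}
\]
```

The split drops the MSE from 16.25 to 0.25, so a regression tree would strongly prefer it over leaving the node unsplit.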
Note: The second kind of tree is called a Conditional Partitioning Tree (also known as a conditional inference tree). Instead of using IG as the splitting criterion, this tree is grown using hypothesis tests at every node, and the variable most significantly associated with the response is chosen as the variable to split on.
Pruning
A DT grown using the process above may produce good predictions on the training set, but is likely to overfit the data, leading to poor test set performance. A way to reduce the variance (overfitting) is to grow a smaller tree, at the cost of a little more bias.
One possible way to tackle this is to build the tree only as long as the decrease in the RSS or increase in IG due to each split exceeds some (high) threshold. This strategy will result in smaller trees, but is too “risky”, since a seemingly worthless split early on in the tree might be followed by a very good split later on.
Another way is to grow the tree to full depth and then prune it back to a subtree that gives the lowest possible test error rate. Later on we will see how to calculate the CV error for a tree. If we record the CV error for all possible subtrees, we can select the subtree with the lowest error. However, this is inefficient, especially for a larger tree with a lot of internal nodes. With the prune function in R, we instead tell rpart to keep a split only if it improves the fit by at least a given amount (specified by the CP, or cost complexity parameter). Splits that don’t meet this threshold are cut back.
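To see the effect of the CP threshold in isolation, here is a minimal sketch on the kyphosis data that ships with rpart (not the Pokemon data). The settings cp = 0 and minsplit = 2 are chosen only to force a deliberately oversized tree, and the pruning value 0.05 is an arbitrary illustration.

```r
library(rpart)  # rpart and its kyphosis data ship with standard R installations

# Grow a deliberately large tree
set.seed(1)
full_tree <- rpart(Kyphosis ~ Age + Number + Start,
                   data = kyphosis, method = "class",
                   control = rpart.control(cp = 0, minsplit = 2))

# Cut back every split that does not improve the fit by at least cp
small_tree <- prune(full_tree, cp = 0.05)

# Number of terminal nodes before and after pruning
n_leaves <- c(full   = sum(full_tree$frame$var == "<leaf>"),
              pruned = sum(small_tree$frame$var == "<leaf>"))
n_leaves
```

A larger cp removes more of the tree; the open question, which cross-validation addresses, is which cp to pick.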
Prediction
Finally we are ready to implement this procedure in R. We will use the Pokemon data to predict if a Pokemon with specific features is Legendary or not.
Split into train-test
train_test_split <- function(data,percent,seed){
  set.seed(seed)
  rows = sample(nrow(data)) # shuffle the row indices
  data <- data[rows,]
  split = round(nrow(data)*percent) # last row of the training set
  list(train = data[1:split,], test = data[(split+1):nrow(data),])
}
splits <- train_test_split(pokemon,0.7,123) # named splits to avoid masking base::list
train <- splits$train; test <- splits$test
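Because legendary Pokemon are rare, a purely random split can leave the train and test sets with somewhat different shares of the positive class. One optional alternative, sketched here with a hypothetical stratified_split helper and made-up toy data, is to sample within each class so that both sets keep the original ratio. This is not what the rest of the tutorial uses.

```r
# Split row indices within each class so both sets keep the class ratio.
stratified_split <- function(data, response, percent, seed) {
  set.seed(seed)
  idx_by_class <- split(seq_len(nrow(data)), data[[response]])
  train_idx <- unlist(lapply(idx_by_class, function(idx) {
    idx[sample.int(length(idx), round(length(idx) * percent))]
  }), use.names = FALSE)
  list(train = data[train_idx, , drop = FALSE],
       test  = data[-train_idx, , drop = FALSE])
}

# Toy data (made up) with a rare positive class: 90 False, 10 True
toy <- data.frame(x = 1:100,
                  Legendary = factor(rep(c("False", "True"), c(90, 10))))
parts <- stratified_split(toy, "Legendary", 0.7, 123)
table(parts$train$Legendary)  # 63 False, 7 True
table(parts$test$Legendary)   # 27 False, 3 True
```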
Recursive partitioning
Recursive partitioning is implemented by the rpart function in the rpart package. You can specify the categorical nature of the response in the method argument. The fancyRpartPlot function in the rattle package helps visualize the tree.
rtree.fit <- rpart(formula = Legendary ~.,
                   method = "class", # classification tree
                   data = train)
rattle::fancyRpartPlot(rtree.fit)
This is R’s representation of the decision tree. We can see that not all variables were used to grow the tree. The most important variable was total points. Let’s take the first node. Here 92% of the observations are not Legendary, and the first split asks whether the total points are less than 580. The label (True/False) above the proportions indicates the way the node is voting, while the numbers below indicate the composition of the node.
Let’s move on to the next node. If the total is less than 580, you move left and reach the second node. This node holds 86% of all the Pokemon, and they are not Legendary. The tree stops here because this node correctly votes all these Pokemon as non-legendary (indicated by the 1 against 0). If we move to node 3 instead, where the total is at least 580, we find the remaining 14% of the Pokemon (do the math: 86 + 14 = 100% of the data, which shows the binary partition). Of these, about 42% were correctly voted and 58% were incorrectly voted as non-legendary. Therefore, this node splits again. This process continues until all nodes (see the nodes below) are as pure as possible. Theoretically one could keep going until all nodes are 100% pure.
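If the plot is hard to read, the same information is available as text from the fitted object. The sketch below uses rpart’s built-in kyphosis data so that it runs on its own; with the tutorial’s model you would inspect rtree.fit the same way.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

print(fit)               # one line per node: split rule, n, loss, predicted class, class proportions
fit$variable.importance  # named vector of importance scores

# Variables that actually appear in a split (leaf rows are labelled "<leaf>")
used <- unique(as.character(fit$frame$var[fit$frame$var != "<leaf>"]))
used
```

Comparing used with the full predictor list shows directly which variables the tree never split on.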
Controlled split
You can further control the splitting process using the rpart.control function and the parms argument. Here is a demonstration:
control <- rpart.control(minsplit = 5,    # min #obs in a node to attempt a split
                         minbucket = 15,  # min #obs in any terminal node
                         cp = 0.07,       # min improvement in fit required for a split
                         maxcompete = 4,  # competitor splits kept in the output
                         xval = 5,        # number of cross validations
                         maxdepth = 15)   # max depth of any node (root = depth 0)
control.tree <- rpart(formula = Legendary ~.,
                      method = "class",
                      data = train,
                      parms = list(prior = c(0.2,0.8), split = "information"), # default is gini
                      control = control)
rattle::fancyRpartPlot(control.tree)
Conditional partitioning
Conditional partitioning is implemented by the ctree function in the party package.
ctree.fit <- ctree(formula = Legendary ~.,
                   data = train)
Legendary or not?
Predictions on the test set are made using the generic predict function.
- Recursive Tree Fit
recursive.prediction <- predict(rtree.fit, newdata = test, type = "class")
- Conditional Tree Fit
conditional.prediction <- predict(ctree.fit, newdata = test)
Prediction accuracy
It’s pretty straightforward to compute the prediction accuracy. In this case, the prediction accuracy is very high (0.94).
predict.df <- data.frame(
  predictions = recursive.prediction,
  actual = test$Legendary
)
(t <- table(predict.df))
## actual
## predictions False True
## False 214 8
## True 6 12
print(paste('Accuracy:',round((t[1]+t[4])/(t[1]+t[2]+t[3]+t[4]),2)))
## [1] "Accuracy: 0.94"
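Accuracy alone can flatter a classifier when one class is rare. As a rough check, the sketch below re-enters the confusion matrix printed above by hand and derives recall, precision and the accuracy of always predicting the majority class; the variable names are my own.

```r
# Confusion matrix copied from the output above (rows = predictions, columns = actual)
cm <- matrix(c(214, 6, 8, 12), nrow = 2,
             dimnames = list(predictions = c("False", "True"),
                             actual      = c("False", "True")))

accuracy  <- sum(diag(cm)) / sum(cm)
recall    <- cm["True", "True"] / sum(cm[, "True"])  # share of legendary Pokemon found
precision <- cm["True", "True"] / sum(cm["True", ])  # share of "True" predictions that are right
baseline  <- max(colSums(cm)) / sum(cm)              # always predicting the majority class

round(c(accuracy = accuracy, recall = recall,
        precision = precision, baseline = baseline), 3)
```

Always answering “False” would already score about 0.92 on this test set, and the tree finds 12 of the 20 legendary Pokemon, so the 0.94 accuracy is a more modest achievement than it first appears.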
Using ROC plot
If the outcome is probabilistic, we can build a ROC curve to assess the classifier. This version makes probabilistic predictions, and the ROC is plotted so we can judge the area under the curve. The prediction and performance functions from the ROCR package are used to compute the TPR and FPR. In this case the strong ROC curve shows how good the classification was.
# Predicted probability of the "True" class from the classification tree
recursive.prediction.prob <- predict(rtree.fit, newdata = test, type = "prob")[, "True"]
rtree_roc <- recursive.prediction.prob %>%
  ROCR::prediction(test$Legendary) %>%
  ROCR::performance("tpr", "fpr")
roc_df <- data.frame(
  FPR = rtree_roc@x.values[[1]],
  TPR = rtree_roc@y.values[[1]],
  cutoff = rtree_roc@alpha.values[[1]]
)
ggplot(roc_df,aes(x = FPR, y = TPR))+
  geom_point()+
  geom_line()+
  theme_minimal()
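The curve can be summarized by the area under it (AUC). ROCR can report this through its "auc" measure; the sketch below instead computes it from scratch with the rank (Mann-Whitney) formula on made-up scores, to show what the number means. The helper name auc_rank is hypothetical.

```r
# AUC via the rank (Mann-Whitney) formula: the probability that a randomly
# chosen positive gets a higher score than a randomly chosen negative.
auc_rank <- function(scores, labels, positive = "True") {
  is_pos <- labels == positive
  n_pos  <- sum(is_pos)
  n_neg  <- sum(!is_pos)
  (sum(rank(scores)[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Toy scores and labels (made up)
scores <- c(0.10, 0.40, 0.35, 0.80)
labels <- c("False", "False", "True", "True")
auc_rank(scores, labels)  # 0.75
```

An AUC of 0.5 corresponds to random guessing and 1 to a perfect ranking of positives above negatives.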
Pruning
Finally, one can prune a tree using the prune function.
my3cols <- c("#E7B800", "#2E9FDF", "#FC4E07")
my2cols <- c("#2E9FDF", "#FC4E07")
cp <- data.frame(rtree.fit$cptable)
best.cp <- cp[which.min(cp$xerror),"CP"] # CP with the lowest cross-validated error
cp %>%
  select(CP,nsplit,xerror,rel.error)%>%
  gather(key,value,-nsplit)%>%
  ggplot(aes(x = nsplit, y = value, col = key))+
  geom_point(size = 2)+
  geom_line(size = 1.1)+
  scale_color_manual(values = my3cols)+
  theme_minimal()
The results show that 3 splits give us the minimum xerror (the cross-validated error, related to the PRESS error) at a CP value of approximately 0.06. We can prune the tree using this CP value. If we chose to use relative error instead, we would choose 6 splits and a CP of 0.01; keep in mind, though, that the relative error is computed on the training data, so it always decreases as splits are added. Different error measures give us different CP values and different numbers of splits. We will create the pruned tree by supplying the CP value with the minimum xerror and get the prediction accuracy.
prune.tree <- prune(rtree.fit, cp = best.cp)
rattle::fancyRpartPlot(prune.tree)
predictions.prune <- predict(prune.tree, newdata = test, type = "class")
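predict.df.prune <- data.frame(
  predictions = predictions.prune,
  actual = test$Legendary
)
(t <- table(predict.df.prune)) # tabulate the pruned tree's predictions
## actual
## predictions False True
## False 214 8
## True 6 12
print(paste('Accuracy:',round((t[1]+t[4])/(t[1]+t[2]+t[3]+t[4]),2)))
## [1] "Accuracy: 0.94"
In this case, pruning doesn’t really improve the prediction accuracy. Note that the output shown here was originally tabulated from predict.df (the unpruned predictions) rather than predict.df.prune, which is why it matches the earlier table exactly; rerun the corrected call above to check the pruned tree’s own accuracy.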
Using Cross-Validation
Resampling methods such as cross-validation can also be used on the training data to pick a CP value for pruning and fitting the best tree. The most convenient way to perform cross-validation is the train function from the caret package. Printing the fitted object shows the best CP value (picked here by accuracy, i.e. 1 minus the misclassification rate), which can then be used to prune the tree.
tc <- caret::trainControl("cv",50,classProbs = T) # 50-fold cross-validation
train.rpart <- caret::train(Legendary ~.,
                            data = train,
                            method="rpart",
                            trControl=tc,
                            tuneLength = 10, # try 10 candidate cp values
                            parms=list(split='information'))
train.rpart
## CART
##
## 560 samples
## 9 predictor
## 2 classes: 'False', 'True'
##
## No pre-processing
## Resampling: Cross-Validated (50 fold)
## Summary of sample sizes: 549, 549, 549, 550, 549, 549, ...
## Resampling results across tuning parameters:
##
## cp Accuracy Kappa
## 0.00000000 0.9359091 0.5002696
## 0.02962963 0.9322424 0.4419367
## 0.05925926 0.9392121 0.5448884
## 0.08888889 0.9410303 0.5786528
## 0.11851852 0.9448182 0.6275213
## 0.14814815 0.9448182 0.6404524
## 0.17777778 0.9464848 0.6612857
## 0.20740741 0.9358788 0.7028026
## 0.23703704 0.9393636 0.7327317
## 0.26666667 0.8984545 0.2406340
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.1777778.
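To make the procedure less of a black box, here is a rough sketch of the same idea written by hand: k-fold cross-validated accuracy of an rpart tree for a few candidate cp values. It uses rpart’s built-in kyphosis data so it runs on its own; the helper name cv_accuracy, the 5 folds and the cp grid are my own choices, and caret’s resampling details differ.

```r
library(rpart)

# k-fold cross-validated accuracy of an rpart classification tree for one cp value
cv_accuracy <- function(data, formula, response, cp, k = 5, seed = 1) {
  set.seed(seed)  # same folds for every cp value
  fold <- sample(rep(seq_len(k), length.out = nrow(data)))
  acc <- sapply(seq_len(k), function(i) {
    fit  <- rpart(formula, data = data[fold != i, ], method = "class",
                  control = rpart.control(cp = cp))
    pred <- predict(fit, newdata = data[fold == i, ], type = "class")
    mean(pred == data[[response]][fold == i])  # accuracy on the held-out fold
  })
  mean(acc)
}

cps <- c(0, 0.01, 0.05, 0.1)
results <- sapply(cps, function(cp)
  cv_accuracy(kyphosis, Kyphosis ~ Age + Number + Start, "Kyphosis", cp))

data.frame(cp = cps, accuracy = results)
cps[which.max(results)]  # cp with the highest cross-validated accuracy
```

The selected cp can then be passed to prune, just as with the value reported by caret.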