A decision tree classifies data by successively splitting the data according to some threshold. To explain what that means, we will first explore how we can use a decision tree in R and then explain the algorithm.

Decision Trees in R

Let us begin by generating some toy data of type XOR.

set.seed(0)
X <- matrix(rnorm(60), ncol = 2)
y <- ifelse(X[, 1] * X[, 2] >= 0, 1, 0)

# Create a data frame from the matrix and labels
data_df <- data.frame(X1 = X[, 1], X2 = X[, 2], y = factor(y))

# Define colors for classes
class_colors <- c("dodgerblue", "firebrick1")

# Plot the data with custom colors
plot(data_df$X1, data_df$X2,
  col = class_colors[data_df$y], pch = 19,
  xlab = "Feature 1", ylab = "Feature 2", main = "Scatter Plot of Data"
)
legend("topright", legend = c("Class 0", "Class 1"), col = class_colors, pch = 19, title = "Classes")

Now we use the tree function from the tree package to classify the data.

library(ggplot2)
library(tree)

# decision tree model
model <- tree(y ~ ., data = data_df)


### Plot result ###
# grid of values for prediction
df <- expand.grid(
  X1 = seq(min(data_df$X1), max(data_df$X1), length.out = 100),
  X2 = seq(min(data_df$X2), max(data_df$X2), length.out = 100)
)

# predicted outcome for each grid point
df$y_pred <- predict(model, df, type = "class")

# visualize decision boundary
ggplot(data_df, aes(X1, X2, fill = y)) +
  geom_raster(data = df, aes(fill = y_pred), alpha = 0.5) +
  geom_point(shape = 21, size = 3) +
  theme_minimal() +
  ggtitle("Decision Tree Decision Boundary for Binary Classification") +
  labs(fill = "Predicted y") +
  theme(plot.title = element_text(hjust = 0.5))

To visualize the tree structure, we fit a classification tree with the rpart package and plot it with the rpart.plot function.

library(rpart)
library(rpart.plot)

dt <- rpart(y ~ X, method = "class", data = data.frame(y, X))
printcp(dt) # display the results
## 
## Classification tree:
## rpart(formula = y ~ X, data = data.frame(y, X), method = "class")
## 
## Variables actually used in tree construction:
## [1] X1
## 
## Root node error: 14/30 = 0.46667
## 
## n= 30 
## 
##        CP nsplit rel error xerror    xstd
## 1 0.21429      0   1.00000 1.2143 0.19387
## 2 0.01000      1   0.78571 1.2143 0.19387
# Plotting Decision Tree
rpart.plot(dt, extra = 106)



What is a decision tree?

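Ok, now you have seen how to build a decision tree, but what does it all mean? In the figure above we see a flowchart-like structure. The tree starts with a single node (the root). Without any split, predicting the majority class misclassifies 14 of the 30 samples, which gives a root node error of approximately 46.67%. The columns rel error and xerror are reported relative to this root node error. As the tree grows (nsplit increases), the training error (rel error) decreases from 1 to 0.78571, because each split divides the data into subsets that are more homogeneous in terms of their class labels. The first split has a complexity parameter CP of 0.21429, which is exactly the improvement in rel error that it brings. The tree stops growing once no further split improves the fit by at least the default threshold of cp = 0.01. The cross-validated error (xerror), which estimates the error on unseen data, stays at 1.2143 in the output above, so the single split does not generalize better than the root node alone. Because the cross-validation folds are drawn at random, the xerror values can differ between runs unless the random seed is fixed. In summary, the output shows how the decision tree was constructed, how it performed at each stage of growth, and how it might generalize to new data. Adjusting the cp parameter controls the trade-off between tree complexity and model accuracy.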
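As a minimal sketch of how the cp parameter could be adjusted, the following code grows a larger tree with a smaller cp and then prunes it back. The values cp = 0.001, minsplit = 5 and the pruning threshold 0.05 are arbitrary choices for illustration, and n_leaves is a small helper defined here, not part of rpart.

```r
library(rpart)

# Recreate the toy data from above
set.seed(0)
X <- matrix(rnorm(60), ncol = 2)
y <- ifelse(X[, 1] * X[, 2] >= 0, 1, 0)
data_df <- data.frame(X1 = X[, 1], X2 = X[, 2], y = factor(y))

# Grow a larger tree: a smaller cp permits splits with a smaller improvement
# (cp = 0.001 and minsplit = 5 are illustrative assumptions)
dt_large <- rpart(y ~ X1 + X2,
  data = data_df, method = "class",
  control = rpart.control(cp = 0.001, minsplit = 5)
)

# Prune back: remove all splits whose complexity parameter is below 0.05
dt_pruned <- prune(dt_large, cp = 0.05)

# Helper (hypothetical name): count the terminal nodes of a fitted tree
n_leaves <- function(fit) sum(fit$frame$var == "<leaf>")
c(large = n_leaves(dt_large), pruned = n_leaves(dt_pruned))
```

Calling printcp(dt_large) shows the cp value that belongs to each tree size, which can help to choose the pruning threshold.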

If we compare the tree above with the feature map above, we see that each decision in the tree corresponds to a horizontal or vertical line in the feature map.

So far, so good, we now understand what a decision tree looks like, but…


How is the tree built?

The goal of the algorithm is to produce pure leaf nodes, i.e. nodes that contain only samples of one class. Thus, if we have a measure of the impurity of a node, the goal is to reduce the impurity from the parent node to the child nodes. In other words, we want to maximize the information gain from splitting. This is defined as \[IG(D_p,f)=I(D_p)-\frac{N_{left}}{N_p}I(D_{left})-\frac{N_{right}}{N_p}I(D_{right}),\] where \(IG\) is the information gain, \(f\) is the feature on which the split is performed, \(I\) is the impurity function, and \(D_p\), \(D_{left}\), \(D_{right}\) and \(N_p\), \(N_{left}\), \(N_{right}\) are the data sets and the numbers of samples of the parent node and of the left and right child nodes. The information gain is thus just the difference between the impurity of the parent node and the weighted mean of the impurities of the child nodes.

For the calculation of the impurity we have several possibilities. The most common ones are the entropy and the Gini impurity, which are defined as \[ \begin{align*} I_{ent}&=-\sum_{i=1}^cp(i|t)\log_2p(i|t)\, ,\\ I_{gini}&=\sum_{i=1}^cp(i|t)\left(1-p(i|t)\right)=1-\sum_{i=1}^cp(i|t)^2\, . \end{align*}\]
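The two impurity measures can be sketched in a few lines of base R. The function names entropy and gini are chosen here for illustration; the class probabilities are estimated as the relative class frequencies in the node.

```r
# Entropy of a vector of class labels (in bits)
entropy <- function(labels) {
  p <- table(labels) / length(labels)
  p <- p[p > 0] # drop empty classes to avoid 0 * log2(0) = NaN
  -sum(p * log2(p))
}

# Gini impurity of a vector of class labels
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# A pure node has zero impurity, a balanced binary node has maximal impurity
pure <- c(1, 1, 1, 1)
balanced <- c(0, 0, 1, 1)

entropy(pure) # 0
gini(pure) # 0
entropy(balanced) # 1
gini(balanced) # 0.5
```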

Within both formulas, the sum runs over all \(c\) classes and \(p(i|t)\) is the probability of finding a member of class \(i\) among the samples of node \(t\), i.e. \(N_{class,i}/N_{node}\). In practice the results with both measures are similar, and in most cases the Gini impurity is used as the default. The algorithm is thus:

  1. Split the root node at one value of one feature.
  2. Compute the information gain.
  3. Repeat the splitting for different values and features, i.e. scan the entire feature space.
  4. Find the split that maximizes the information gain.
  5. Split the node accordingly and continue with step 1 for all child nodes until the child nodes are pure.
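The search for the best split of a single node (steps 1 to 4) could be sketched in base R as follows. This is a simplified illustration and not the implementation used by the tree or rpart packages. The function names info_gain and best_split are hypothetical, the Gini impurity is used as impurity function, and the candidate thresholds are assumed to be the midpoints between consecutive observed values. The recursion over the child nodes (step 5) is omitted.

```r
# Gini impurity of a vector of class labels
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Information gain of splitting a node at feature <= threshold
info_gain <- function(feature, labels, threshold) {
  left <- labels[feature <= threshold]
  right <- labels[feature > threshold]
  n <- length(labels)
  gini(labels) -
    length(left) / n * gini(left) -
    length(right) / n * gini(right)
}

# Scan all features and all candidate thresholds of one node
best_split <- function(data, target) {
  features <- setdiff(names(data), target)
  best <- list(feature = NA, threshold = NA, gain = -Inf)
  for (f in features) {
    values <- sort(unique(data[[f]]))
    if (length(values) < 2) next
    # Candidate thresholds: midpoints between consecutive observed values
    thresholds <- (head(values, -1) + tail(values, -1)) / 2
    for (th in thresholds) {
      gain <- info_gain(data[[f]], data[[target]], th)
      if (gain > best$gain) {
        best <- list(feature = f, threshold = th, gain = gain)
      }
    }
  }
  best
}

# Toy data from above
set.seed(0)
X <- matrix(rnorm(60), ncol = 2)
y <- ifelse(X[, 1] * X[, 2] >= 0, 1, 0)
data_df <- data.frame(X1 = X[, 1], X2 = X[, 2], y = factor(y))

# Best first split of the root node
best_split(data_df, "y")
```

Applying best_split again to the two subsets produced by the returned threshold would give the next level of the tree.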

Some notes

  • The result of the tree is easy to interpret.
  • Decision trees require little data preprocessing before training and work with both categorical and continuous features.
  • Decision trees can classify non-linear problems.
  • Decision trees tend to overfit, because without restrictions they partition the data until each training sample is perfectly classified. We can deal with this problem by limiting the maximum depth of the tree or by pruning it.
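As a sketch of how the depth of a tree could be restricted, the following code compares an unrestricted tree with a tree of maximum depth 2, using the maxdepth argument of rpart.control. The settings minsplit = 2 and cp = 0 are chosen here only to let the unrestricted tree grow as far as possible; n_leaves and train_acc are helper functions defined for this example.

```r
library(rpart)

# Recreate the toy data from above
set.seed(0)
X <- matrix(rnorm(60), ncol = 2)
y <- ifelse(X[, 1] * X[, 2] >= 0, 1, 0)
data_df <- data.frame(X1 = X[, 1], X2 = X[, 2], y = factor(y))

# Tree without depth restriction (minsplit = 2 and cp = 0 are assumptions
# that allow splitting down to very small nodes)
dt_deep <- rpart(y ~ X1 + X2,
  data = data_df, method = "class",
  control = rpart.control(minsplit = 2, cp = 0)
)

# Same settings, but at most two levels of splits below the root
dt_shallow <- rpart(y ~ X1 + X2,
  data = data_df, method = "class",
  control = rpart.control(maxdepth = 2, minsplit = 2, cp = 0)
)

# Helpers (hypothetical names): number of leaves and training accuracy
n_leaves <- function(fit) sum(fit$frame$var == "<leaf>")
train_acc <- function(fit) {
  mean(predict(fit, data_df, type = "class") == data_df$y)
}

c(deep = n_leaves(dt_deep), shallow = n_leaves(dt_shallow))
c(deep = train_acc(dt_deep), shallow = train_acc(dt_shallow))
```

A high training accuracy of the deep tree is not evidence of good generalization; it should be checked on held-out data or by cross-validation.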

Random Forest

Decision trees can also be used with so-called ensemble methods. An example is a random forest, which follows four simple steps:

  1. Draw a subsample from the training data.
  2. Build a decision tree on the subsample.
  3. Repeat steps 1 and 2 several times.
  4. Classify new data according to the majority vote of the trees in the forest.
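The four steps can be sketched by hand with rpart trees. Note that this is plain bagging of decision trees and only an illustration: the randomForest package additionally considers a random subset of the features at each split. The number of trees (25), the tree settings and the function name predict_forest are assumptions made for this example.

```r
library(rpart)

# Recreate the toy data from above
set.seed(0)
X <- matrix(rnorm(60), ncol = 2)
y <- ifelse(X[, 1] * X[, 2] >= 0, 1, 0)
data_df <- data.frame(X1 = X[, 1], X2 = X[, 2], y = factor(y))

n_trees <- 25 # illustrative choice

# Steps 1 to 3: fit one tree per bootstrap sample (rows drawn with replacement)
forest <- lapply(seq_len(n_trees), function(i) {
  idx <- sample(nrow(data_df), replace = TRUE)
  rpart(y ~ X1 + X2,
    data = data_df[idx, ], method = "class",
    control = rpart.control(minsplit = 2, cp = 0)
  )
})

# Step 4: classify new data by the majority vote of all trees
predict_forest <- function(forest, newdata, class_levels) {
  votes <- sapply(forest, function(tree) {
    as.character(predict(tree, newdata, type = "class"))
  })
  # Keep one row per sample, also if newdata has a single row
  votes <- matrix(votes, nrow = nrow(newdata))
  majority <- apply(votes, 1, function(v) names(which.max(table(v))))
  factor(majority, levels = class_levels)
}

pred <- predict_forest(forest, data_df, levels(data_df$y))
mean(pred == data_df$y) # accuracy on the training data
```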

The following figure shows the result of a random forest with 100 trees.

library(randomForest)
library(ggplot2)

# random forest model
rf <- randomForest(y ~ X1 + X2, data = data_df, proximity = TRUE, ntree = 100)

df <- expand.grid(
  X1 = seq(min(data_df$X1), max(data_df$X1), length.out = 100),
  X2 = seq(min(data_df$X2), max(data_df$X2), length.out = 100)
)

# predicted outcome for each grid point
df$y_pred <- predict(rf, df)

#  visualize decision boundary
ggplot(data_df, aes(X1, X2, fill = y)) +
  geom_raster(data = df, aes(fill = y_pred), alpha = 0.5) +
  geom_point(shape = 21, size = 3) +
  theme_minimal() +
  ggtitle("Random Forest Decision Boundary for Binary Classification") +
  labs(fill = "Predicted y") +
  theme(plot.title = element_text(hjust = 0.5))



Citation

The E-Learning project SOGA-R was developed at the Department of Earth Sciences by Kai Hartmann, Joachim Krois and Annette Rudolph. You can reach us by e-mail at soga[at]zedat.fu-berlin.de.

Creative Commons License
You may use this project freely under the Creative Commons Attribution-ShareAlike 4.0 International License.

Please cite as follows: Hartmann, K., Krois, J., Rudolph, A. (2023): Statistics and Geodata Analysis using R (SOGA-R). Department of Earth Sciences, Freie Universitaet Berlin.