A decision tree classifies data by successively splitting it according to threshold values. To explain what this really means, we will change our usual approach: first we will see how to use a decision tree in Python, and then we will explain the algorithm.

Decision Trees in Python¶

Before we actually start, let us again define a function to plot the results.

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def plot_decision_boundary(X, y, classifier, resolution=0.02):
    """
    Modified version of an implementation in
    Sebastian Raschka and Vahid Mirjalili,
    Python Machine Learning,
    2nd ed., 2017, Packt Publishing
    """
    markers = ('o', 's')
    colors = ('tab:blue', 'tab:orange')
    cmap = ListedColormap(colors)
    
    # define the grid
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution))
    
    if classifier is not None:
        # for each grid point, predict the class
        lab = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
        lab = lab.reshape(xx1.shape)
    
        # plot the decision regions
        plt.contourf(xx1, xx2, lab, alpha=0.3, cmap=cmap)
        plt.xlim(xx1.min(), xx1.max())
        plt.ylim(xx2.min(), xx2.max())
    
    # plot the data points
    for idx, cl in enumerate(np.unique(y)):
        plt.scatter(x=X[y == cl, 0],
                    y=X[y == cl, 1],
                    alpha=0.8,
                    c=colors[idx],
                    marker=markers[idx],
                    label=f'Class {cl}')
    # set axis labels and legend once, after all classes have been plotted
    plt.xlabel('feature 1')
    plt.ylabel('feature 2')
    plt.legend()

As data, we will use random toy data that exhibits an XOR problem.

In [3]:
# create some toy data
rng = np.random.default_rng(seed=0)
X = rng.standard_normal((30, 2))

# assign class labels: class 1 if both features have the same sign (an XOR-like rule)
y = np.where(X[:, 0] * X[:, 1] >= 0, 1, 0)

# plot the data
plot_decision_boundary(X, y, None)
plt.title('Toy data')
plt.show()

Now we use the DecisionTreeClassifier from the scikit-learn library to classify the data.

In [4]:
from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(criterion='gini',
                            max_depth=3,
                            random_state=0)
dt.fit(X, y)

plot_decision_boundary(X, y, dt)
plt.title('Scikit-Learn Decision Tree')
plt.show()

As a nice feature, there is a function in the scikit-learn library that plots the resulting tree structure for us.

In [5]:
from sklearn import tree
tree.plot_tree(dt,
               feature_names=['feat 1', 'feat 2'],
               filled=True,
               rounded=True)
plt.show()

In a real example, we would of course need to check the accuracy of our trained model. We will do this when we talk about the "standard" machine learning workflow.
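
As a brief preview, a minimal sketch of such a check could look like this, assuming we simply hold out 30 % of the toy data (the split ratio and variable names are chosen for illustration only):

In [ ]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# hold out 30 % of the toy data for testing (purely illustrative)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

dt_check = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=0)
dt_check.fit(X_train, y_train)

# fraction of correctly classified hold-out samples
print(accuracy_score(y_test, dt_check.predict(X_test)))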

What is a decision tree?¶

Ok, now you have seen how to build a decision tree with scikit-learn, but what does it all mean? In the figure above we see a flowchart-like structure. At the top there is a root node, which starts with the whole data set. This set is then divided by a decision, in this case: Is the value of feature 1 <= 0.385? All samples for which this is true are collected in a child node on the left branch, all others in a child node on the right branch. In the above case, we start with 30 samples in the root node and divide them into 21 samples in the left branch and 9 in the right branch. This division process is then continued; for example, the left branch is further divided by the question: Is the value of feature 2 <= -0.044? and so on. The process stops once a node contains only samples of one class. These nodes are called leaf nodes. In the tree above, we end up with four leaf nodes at the bottom and one right at the top under the root node. The value entry within each node shows how many samples of each class it contains, and the color corresponds to the majority class within the node.
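
If you prefer a plain-text view of the learned splits, scikit-learn also offers export_text; the following small snippet (an optional extra, not part of the original workflow) prints the same tree as indented text:

In [ ]:
from sklearn.tree import export_text

# print the learned thresholds and predicted leaf classes as indented text
print(export_text(dt, feature_names=['feat 1', 'feat 2']))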

If we compare the tree above with the feature map above, we see that each decision in the tree corresponds to a horizontal or vertical line in the feature map.

So far, so good, we now understand what a decision tree looks like, but...

How is the tree built up?¶

The goal of the algorithm is to produce pure leaf nodes that contain only samples of one class. Thus, if we have a measure of the impurity of a node, the goal is to reduce the impurity from the parent node to the child nodes. In other words, we want to maximize the information gain from splitting. This is defined as

$$ IG(D_p, f)=I(D_p)-\frac{N_{left}}{N_p}I(D_{left})-\frac{N_{right}}{N_p}I(D_{right}) \, , $$

where $IG$ is the information gain, $I$ is the impurity measure, $f$ is the feature used for the split, and $D$ and $N$ denote the data set and the number of samples of the parent node ($p$) and of the $left$ and $right$ child nodes, respectively. The information gain is thus simply the difference between the impurity of the parent node and the weighted mean of the impurities of the child nodes.

For the calculation of the impurity there are several possibilities. The most common ones are the entropy and the Gini impurity, which are defined as

$$ \begin{align*} I_{ent}&=-\sum_{i=1}^cp(i|t)\log_2p(i|t)\, ,\\ I_{gini}&=\sum_{i=1}^cp(i|t)\left(1-p(i|t)\right)=1-\sum_{i=1}^cp(i|t)^2\, . \end{align*} $$

In both formulas, the sum runs over all $c$ classes and $p(i|t)$ is the probability of finding a member of class $i$ among the samples in node $t$, i.e. $N_{class,\, i}/N_{node}$. In practice the results obtained with both measures are similar, and in most cases the Gini impurity is used as the default.
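
To make these formulas concrete, here is a small illustrative computation for a hypothetical node containing 20 samples of class 0 and 10 samples of class 1 (the counts are made up for this sketch):

In [ ]:
def gini(counts):
    """Gini impurity of a node, given the class counts within the node."""
    p = np.asarray(counts) / np.sum(counts)
    return 1.0 - np.sum(p**2)

def entropy(counts):
    """Entropy of a node, given the class counts within the node."""
    p = np.asarray(counts) / np.sum(counts)
    p = p[p > 0]  # treat 0 * log2(0) as 0
    return -np.sum(p * np.log2(p))

# hypothetical node with 20 samples of class 0 and 10 samples of class 1
print(gini([20, 10]))     # 1 - ((2/3)**2 + (1/3)**2) = 0.444...
print(entropy([20, 10]))  # -(2/3*log2(2/3) + 1/3*log2(1/3)) = 0.918...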

The algorithm thus proceeds as follows (a code sketch of the split search is shown after the list):

  1. Split the root node at a threshold of one feature.
  2. Compute the information gain.
  3. Repeat the split for different thresholds and features, scanning the entire feature space.
  4. Keep the split that maximizes the information gain.
  5. Apply this split and continue with step 1 for each child node until the child nodes are pure.
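
The following is a minimal from-scratch sketch of this search for a single split (for illustration only; it is not how scikit-learn implements it internally). For every feature it scans the midpoints between consecutive feature values as candidate thresholds and keeps the split with the largest Gini-based information gain:

In [ ]:
def gini_from_labels(labels):
    """Gini impurity of a set of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p**2)

def best_split(X, y):
    """Scan all features and candidate thresholds and return the split
    with the largest information gain (based on the Gini impurity)."""
    n_samples, n_features = X.shape
    parent_impurity = gini_from_labels(y)
    best_feature, best_threshold, best_gain = None, None, 0.0

    for f in range(n_features):
        # candidate thresholds: midpoints between consecutive sorted feature values
        values = np.unique(X[:, f])
        for t in (values[:-1] + values[1:]) / 2:
            left, right = y[X[:, f] <= t], y[X[:, f] > t]
            gain = (parent_impurity
                    - len(left) / n_samples * gini_from_labels(left)
                    - len(right) / n_samples * gini_from_labels(right))
            if gain > best_gain:
                best_feature, best_threshold, best_gain = f, t, gain
    return best_feature, best_threshold, best_gain

# the best first split for the root node of the toy data
print(best_split(X, y))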

Some notes¶

  • The result of the tree is easy to interpret.
  • Decision trees do not require much data preparation before training and work on both categorical and continuous features.
  • Decision trees can classify non-linear problems.
  • Decision trees tend to overfit, because they partition the data until each sample is perfectly classified. We can deal with this problem by setting a maximum depth for the tree. Above, we chose a maximum depth of three (even though the tree only needed two levels); a short sketch comparing a fully grown tree with the constrained one follows after this list.
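
To see this effect, we can grow a tree without a depth limit and compare it with the constrained tree from above. This is only a small sketch; on such a small, noise-free toy data set the difference may be minor, but on larger and noisier data an unrestricted tree can become very deep:

In [ ]:
# fully grown tree: keeps splitting until every leaf is pure
dt_full = DecisionTreeClassifier(criterion='gini', random_state=0)
dt_full.fit(X, y)

print('unrestricted tree:', dt_full.get_depth(), 'levels,', dt_full.get_n_leaves(), 'leaves')
print('max_depth=3 tree :', dt.get_depth(), 'levels,', dt.get_n_leaves(), 'leaves')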

Random Forest¶

Decision trees can also be used with so-called ensemble methods. An example is a random forest, which follows four simple steps (sketched in code after the list):

  1. Draw a subsample from the training data.
  2. Build a decision tree on the subsample.
  3. Repeat steps 1 and 2 several times.
  4. Classify new data according to the majority vote of the trees in the forest.
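
These four steps can be sketched by hand in a few lines, using bootstrap samples drawn with the random generator from above (an illustrative toy version; a real random forest additionally considers only a random subset of the features at each split):

In [ ]:
n_trees = 10
trees = []

for i in range(n_trees):
    # step 1: draw a bootstrap subsample (sampling with replacement)
    idx = rng.integers(0, len(X), size=len(X))
    # step 2 (and 3, via the loop): build a decision tree on the subsample
    trees.append(DecisionTreeClassifier(random_state=i).fit(X[idx], y[idx]))

# step 4: majority vote of all trees (ties are counted as class 1 here)
votes = np.array([t.predict(X) for t in trees])
y_vote = np.where(votes.mean(axis=0) >= 0.5, 1, 0)
print(np.mean(y_vote == y))  # accuracy of the hand-made ensemble on the training data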

The following figure shows the result of a random forest with 10 trees.

In [6]:
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(n_estimators=10,
                            random_state=0)
rf.fit(X, y)

plot_decision_boundary(X, y, rf)
plt.title('Scikit-Learn Random Forest')
plt.show()

Citation

The E-Learning project SOGA-Py was developed at the Department of Earth Sciences by Annette Rudolph, Joachim Krois and Kai Hartmann. You can reach us via e-mail at soga[at]zedat.fu-berlin.de.

You may use this project freely under the Creative Commons Attribution-ShareAlike 4.0 International License.

Please cite as follows: Rudolph, A., Krois, J., Hartmann, K. (2023): Statistics and Geodata Analysis using Python (SOGA-Py). Department of Earth Sciences, Freie Universitaet Berlin.