Published on: September 29, 2020
Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data. DTs are highly interpretable, capable of achieving high accuracy for many tasks while requiring little data preparation.
Creating a decision tree – Recursive Binary Splitting
Growing a tree involves repeatedly splitting the data into subsets so as to minimize a cost function. At each step, all features are considered, and for each feature different split points are tried and evaluated with the cost function. The split with the lowest cost is then selected. The process is repeated until a stopping criterion is met (discussed below). The algorithm is recursive because the groups formed by each split can be subdivided using the same strategy. It is also greedy: each split is the best one available at that step, chosen without looking ahead to later splits.
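The following is a minimal from-scratch sketch of the procedure in Python with NumPy. It makes several assumptions. It uses the Gini index (described below) as the cost function. It stores nodes as plain dictionaries. The helper names (`best_split`, `build_tree`, `predict_one`) and stopping parameters (`max_depth`, `min_samples`) are hypothetical. The sketch illustrates greedy recursive binary splitting and does not reflect how scikit-learn implements it internally.

```python
import numpy as np


def gini(y):
    """Gini impurity of a label array (0 means the group is pure)."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def best_split(X, y):
    """Try every feature and every observed value as a threshold.
    Return (feature, threshold, cost) of the split with the lowest weighted Gini."""
    best = (None, None, np.inf)
    n = len(y)
    for feature in range(X.shape[1]):
        for threshold in np.unique(X[:, feature]):
            left = X[:, feature] <= threshold
            right = ~left
            if left.all() or right.all():
                continue  # one side would be empty, so this is not a real split
            cost = (left.sum() * gini(y[left]) + right.sum() * gini(y[right])) / n
            if cost < best[2]:
                best = (feature, threshold, cost)
    return best


def build_tree(X, y, depth=0, max_depth=3, min_samples=2):
    """Recursively split until a stopping criterion is hit.
    Leaves store the majority class of the training samples that reach them."""
    values, counts = np.unique(y, return_counts=True)
    majority = values[np.argmax(counts)]
    if depth >= max_depth or len(y) < min_samples or len(values) == 1:
        return {"leaf": majority}
    feature, threshold, _ = best_split(X, y)
    if feature is None:  # no valid split exists (e.g. all rows identical)
        return {"leaf": majority}
    left = X[:, feature] <= threshold
    return {
        "feature": feature,
        "threshold": threshold,
        "left": build_tree(X[left], y[left], depth + 1, max_depth, min_samples),
        "right": build_tree(X[~left], y[~left], depth + 1, max_depth, min_samples),
    }


def predict_one(node, x):
    """Walk from the root to a leaf and return the leaf's class."""
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]


# Toy data: one feature, two classes separated between 2 and 3.
X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([0, 0, 1, 1])
toy_tree = build_tree(X, y)
print(predict_one(toy_tree, [1.5]), predict_one(toy_tree, [3.5]))  # 0 1
```

On this toy data, the root split is `x[0] <= 2.0`. That split yields two pure groups, so both children immediately become leaves.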
Cost of a split
The cost of a split measures how good it is to split on a specific feature at a specific value. For regression, cost functions such as the sum of squared errors or the standard deviation are used.
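The sketch below uses the sum of squared errors as a regression split cost, assuming each group predicts its own mean. The function names `sse` and `split_cost_sse` and the toy data are hypothetical. The example scores every candidate threshold on a single feature and keeps the cheapest one.

```python
import numpy as np


def sse(y):
    """Sum of squared errors around the group mean (the mean is what a leaf would predict)."""
    return float(np.sum((y - y.mean()) ** 2)) if len(y) else 0.0


def split_cost_sse(x, y, threshold):
    """Cost of splitting feature x at threshold: SSE(left group) + SSE(right group)."""
    left = x <= threshold
    return sse(y[left]) + sse(y[~left])


# Toy data: the target jumps from about 1 to about 5 between x = 3 and x = 4.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([1.0, 1.2, 0.9, 5.0, 5.1, 4.9])

# The largest x is skipped because it would leave the right group empty.
costs = {t: split_cost_sse(x, y, t) for t in x[:-1]}
best = min(costs, key=costs.get)
print(best)  # 3.0
```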
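For classification, the Gini index (Gini impurity) is used:

$$G = 1 - \sum_{i \in J} p_i^2$$

where $J$ is the set of all classes, and $p_i$ is the fraction of items belonging to class $i$. The cost of a split is the Gini impurity of each resulting group, weighted by the fraction of samples that fall into that group. Ideally, a split has a cost of zero, which means that every resulting group contains only one class. For a two-class problem, the worst Gini impurity is 0.5, which occurs when a group is split 50-50 between the two classes.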
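The formula translates directly into code. The minimal sketch below uses hypothetical helper names `gini` and `split_gini`. It checks the two reference points mentioned above: a pure group has impurity 0, and a two-class 50-50 group has impurity 0.5. It also shows that with three classes the maximum rises above 0.5.

```python
import numpy as np


def gini(labels):
    """Gini impurity G = 1 - sum_i p_i^2 over the classes present in `labels`."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - float(np.sum(p ** 2))


def split_gini(left, right):
    """Cost of a split: each group's Gini weighted by its share of the samples."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


print(gini([0, 0, 0, 0]))          # 0.0   pure group
print(gini([0, 0, 1, 1]))          # 0.5   two classes split 50-50
print(round(gini([0, 1, 2]), 3))   # 0.667 three equally frequent classes
print(split_gini([0, 0], [1, 1]))  # 0.0   perfect split
```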
When should you stop splitting?
Now you might ask when to stop growing the tree. This is an important question, because if we kept splitting, the decision tree would quickly grow huge. Such complex trees are slow and tend to overfit. Therefore, we set a predefined stopping criterion to halt the construction of the decision tree.
The two most common stopping methods are:
- Minimum count of training examples assigned to a leaf node, e.g., if there are fewer than 10 training points, stop splitting.
- Maximum depth (maximum length from root to leaf)
A larger tree might fit the training data better, but it is also more prone to overfitting. Conversely, a minimum count that is too large or a maximum depth that is too small can stop training too early and result in poor performance.
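In scikit-learn, both stopping criteria map onto constructor parameters of `DecisionTreeClassifier`. `max_depth` caps the depth of the tree. `min_samples_split` is the minimum number of samples a node needs before it may be split, which matches the "fewer than 10 points, stop splitting" example; the related `min_samples_leaf` instead sets a minimum size for each leaf. The sketch below compares an unrestricted tree with a limited one on the Iris data. The specific values (3 and 10) are illustrative choices, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Defaults: nodes are expanded until leaves are pure (or cannot be split further).
full = DecisionTreeClassifier(random_state=0).fit(X, y)

# Stop early: at most 3 levels deep, and never split a node holding fewer than 10 samples.
limited = DecisionTreeClassifier(max_depth=3, min_samples_split=10, random_state=0).fit(X, y)

print("full:   ", full.get_depth(), "levels,", full.get_n_leaves(), "leaves")
print("limited:", limited.get_depth(), "levels,", limited.get_n_leaves(), "leaves")
```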
Pruning is a technique that reduces the size of a decision tree by removing sections of the tree that contribute little to its predictions. Pruning reduces the complexity of the final model and can therefore improve predictive accuracy by reducing overfitting.
There are multiple pruning techniques available. In this article, we'll focus on two:
- Reduced error pruning
- Cost complexity pruning
Reduced error pruning
One of the simplest forms of pruning is reduced error pruning. Starting at the leaves, each node is replaced with a leaf that predicts its most popular class. If the loss, typically measured on a held-out validation set, does not get worse, the change is kept; otherwise, it is reverted. Although it is a somewhat naive approach to pruning, reduced error pruning has the advantages of speed and simplicity.
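Here is a minimal sketch of reduced error pruning. It assumes a hypothetical `Node` class in which every node remembers the most popular training class that reached it. The loss is the number of misclassified points in a small validation set, also an assumption. Nodes are visited bottom-up (children before parents), and a collapse is kept whenever the validation error does not increase.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    majority: int                    # most popular training class at this node (hypothetical field)
    feature: Optional[int] = None    # None marks a leaf
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def predict(node: Node, x: List[float]) -> int:
    while node.feature is not None:
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.majority


def errors(root: Node, X_val, y_val) -> int:
    """Number of validation points the whole tree gets wrong."""
    return sum(predict(root, x) != y for x, y in zip(X_val, y_val))


def reduced_error_prune(root: Node, node: Node, X_val, y_val) -> None:
    """Prune children first, then try replacing `node` with a leaf of its majority class."""
    if node.feature is None:
        return
    reduced_error_prune(root, node.left, X_val, y_val)
    reduced_error_prune(root, node.right, X_val, y_val)
    before = errors(root, X_val, y_val)
    saved = (node.feature, node.left, node.right)
    node.feature, node.left, node.right = None, None, None  # collapse into a leaf
    if errors(root, X_val, y_val) > before:
        node.feature, node.left, node.right = saved  # pruning hurt: revert


# Hypothetical tree whose right subtree splits on noise (x > 7 predicts class 0).
tree = Node(majority=0, feature=0, threshold=5.0,
            left=Node(majority=0),
            right=Node(majority=1, feature=0, threshold=7.0,
                       left=Node(majority=1), right=Node(majority=0)))
X_val = [[2.0], [6.0], [8.0], [9.0]]
y_val = [0, 1, 1, 1]

print(errors(tree, X_val, y_val))  # 2 before pruning
reduced_error_prune(tree, tree, X_val, y_val)
print(errors(tree, X_val, y_val))  # 0 after the noisy subtree is collapsed
```

In this example, collapsing the right subtree removes both validation errors, so the change is kept. Collapsing the root would misclassify three points, so that change is reverted.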
Cost complexity pruning
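Cost complexity pruning, also known as weakest link pruning, is a more sophisticated pruning method. It creates a series of trees T0 to Tn, where T0 is the initial tree and Tn is the root alone. The tree at step i is created by removing a subtree from tree i-1 and replacing it with a leaf node. The subtree removed is the "weakest link": the one whose removal increases the training error the least per leaf removed. Equivalently, each tree in the sequence minimizes the cost-complexity measure $R_\alpha(T) = R(T) + \alpha |T|$ for increasing values of $\alpha$. Here $R(T)$ is the training error of tree $T$, and $|T|$ is its number of leaves. The final tree is then usually chosen from this sequence using cross-validation or a validation set.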
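scikit-learn exposes this procedure through the `cost_complexity_pruning_path` method and the `ccp_alpha` parameter. The sketch below computes the effective alphas for a tree trained on the breast cancer data. The choice of data set and the train/test split are illustrative assumptions. It then fits one pruned tree per alpha, from the full tree (alpha = 0) down to the root alone.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Effective alphas at which successive weakest-link subtrees are pruned away.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# One tree per alpha: larger alpha means more aggressive pruning and a smaller tree.
trees = [DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_train, y_train)
         for a in ccp_alphas]

for alpha, t in zip(ccp_alphas, trees):
    print(f"alpha={alpha:.4f}  leaves={t.get_n_leaves()}  test acc={t.score(X_test, y_test):.3f}")
```

In practice, you would pick the alpha with the best validation or cross-validation score rather than simply the largest or smallest one.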
Decision trees for both classification and regression are super easy to use in Scikit-Learn.
To load in the Iris data-set, create a decision tree object, and train it on the Iris data, the following code can be used:
```python
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
X = iris.data
y = iris.target

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
```
Once trained, you can plot the tree with the plot_tree function, e.g., tree.plot_tree(clf).
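Here is a slightly fuller sketch, assuming matplotlib is installed. It labels the nodes with the Iris feature and class names, colors them by majority class, and saves the figure to a file. The filename `iris_tree.png` is arbitrary.

```python
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier().fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))
# plot_tree draws one annotation box per node and returns the list of annotations.
annotations = tree.plot_tree(clf,
                             feature_names=iris.feature_names,
                             class_names=list(iris.target_names),
                             filled=True,
                             ax=ax)
fig.savefig("iris_tree.png")  # in an interactive session, plt.show() works as well
```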
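The tree can also be exported in Graphviz format with the export_graphviz function. This requires the graphviz Python package. You can install it with pip install graphviz, in which case the Graphviz system binaries must be installed separately. If you use conda, conda install python-graphviz installs both the package and the binaries. Below is an example Graphviz export of the above tree.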
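```python
import graphviz

# Export the fitted tree as Graphviz DOT source
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("iris")  # writes the DOT source to "iris" and the rendered tree to "iris.pdf"
```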
Alternatively, the tree can be exported in text format with the export_text function.
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text

iris = load_iris()
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
```
```
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2
```
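Regression trees work the same way. The following code trains a DecisionTreeRegressor on the diabetes data-set:

```python
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
regressor = DecisionTreeRegressor(random_state=0)
regressor = regressor.fit(X, y)
```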
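To see the overfitting discussed earlier, you can evaluate the regressor on data it has not seen. The sketch below holds out part of the diabetes data with `train_test_split`, an illustrative choice. It compares the R² score of an unrestricted tree with that of a tree limited to `max_depth=3`; the depth value is arbitrary. The unrestricted tree fits its training data almost perfectly, but that fit does not carry over to the test set.

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Unrestricted tree: grows until the leaves are pure.
full = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)
train_r2 = full.score(X_train, y_train)  # score() returns R^2 for regressors
test_r2 = full.score(X_test, y_test)

# Depth-limited tree for comparison.
shallow = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_train, y_train)

print(f"full:    train R^2 = {train_r2:.3f}, test R^2 = {test_r2:.3f}")
print(f"shallow: train R^2 = {shallow.score(X_train, y_train):.3f}, "
      f"test R^2 = {shallow.score(X_test, y_test):.3f}")
print(full.predict(X_test[:3]))  # predictions for the first three test samples
```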