A flowchart of yes/no questions, learned automatically from data.
Key idea
A series of yes/no questions that funnel data to a prediction. "Is income > £30k?" → "Yes" → "Is age > 40?" → "Yes" → predict "high credit". The algorithm picks the questions and the order, learning them from training data.
[Interactive demo: a max-depth slider (shown at depth 3). As depth changes, the tree carves axis-aligned regions and the tree diagram on the right grows.]
Every split is axis-aligned — so a single tree can only carve feature space into rectangles. That's fine for the Linear boundary dataset (one split does it) but watch what happens on XOR: depth 1 can't separate the classes at all, but depth 2 nails it with two perpendicular splits. Crank to depth 8 on a noisy dataset and you'll see overfitting — every training point gets its own tiny region.
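The XOR behaviour described above can be reproduced outside the demo. This is a minimal sketch that assumes the simplest noise-free case: four corner points, with the class set to 1 when exactly one coordinate is 1. The data and variable names are illustrative and not taken from the interactive page.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Four-point XOR (an illustrative assumption): class 1 when exactly one coordinate is 1.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])

scores = {}
for depth in (1, 2):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X, y)
    scores[depth] = clf.score(X, y)
    print(f"max_depth={depth}: training accuracy = {scores[depth]:.2f}")
```

With a single split, each side still holds one point of each class, so the tree can do no better than chance. Adding a second, perpendicular split isolates every corner. On noisy samples the greedy first split may land somewhere less helpful, so treat this as the idealised case.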
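Decision trees are perhaps the most intuitive ML model: you can read one off and explain exactly why it made each prediction. In principle they handle mixed feature types (numeric + categorical), though some libraries, sklearn included, need categoricals encoded as numbers first. They don't need normalisation, and they don't assume a particular functional form (linear, Gaussian and so on) for the data.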
The big downside: a single tree is brittle. Small changes in the training data can produce very different trees (high variance), and a deep tree overfits easily. The fix is ensembles (random forests, gradient boosting), which combine many trees. But you have to understand a single tree first.
Reach for it when
You need a model you can read off and explain
Quick baseline with no feature engineering
Mixed numeric / categorical features
You'll feed it into a random forest or boosted ensemble
Skip it when
You want best possible accuracy — use random forests / GBM instead
Smooth function approximation matters (trees are piecewise-constant)
Linear relationships dominate — a linear model is more elegant
Very high-dim or sparse data
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Assumes X_train is a pandas DataFrame and y_train holds the class labels
clf = DecisionTreeClassifier(max_depth=4, random_state=0)
clf.fit(X_train, y_train)

# Visualise the actual decisions; some sklearn versions require a plain list here
plot_tree(clf, feature_names=list(X_train.columns), filled=True)
plt.show()
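If a plot is awkward (for example in a terminal or a log), the same fitted rules can be printed as text. A small self-contained sketch, assuming the bundled iris dataset as a stand-in for your own training data:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Iris is used here only as a convenient stand-in dataset.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# export_text renders the learned questions as an indented if/else outline.
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Each indented line is one yes/no question, and each "class:" line is a leaf, so the printout is the flowchart in text form.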
In words. At a node, look at the points there and compute each class's fraction; Gini impurity is one minus the sum of the squared fractions (zero when every point is one class, largest when the classes are evenly mixed). For any candidate split, compute the impurity of each child, weight each by the share of points it gets, and subtract that weighted sum from the parent's impurity. That difference is the gain; pick whichever split maximises it. The algorithm tries every feature and every threshold and keeps the winner.
impurity: how mixed the classes are at a node (zero for pure, larger for mixed)
class fraction: share of points belonging to each class at the node
gain: how much the split reduces impurity (bigger is better)
weighted child impurity: each child's impurity multiplied by its share of points, then summed
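The quantities defined above can be computed by hand in a few lines. This is a sketch with made-up labels; the function names gini and gain are illustrative helpers, not a library API.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class fractions."""
    _, counts = np.unique(labels, return_counts=True)
    fractions = counts / counts.sum()
    return 1.0 - np.sum(fractions ** 2)

def gain(parent, left, right):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
    return gini(parent) - weighted

# Made-up node: four points of class 0 and four of class 1.
parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
left, right = parent[:3], parent[3:]  # left is pure; right has one 0 and four 1s

print(f"parent impurity = {gini(parent):.3f}")
print(f"gain of this split = {gain(parent, left, right):.3f}")
```

The parent is a perfect 50/50 mix, so its impurity is 0.5. The left child is pure and the right child is mostly class 1, which is why the split earns a gain of about 0.3.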
CART (Classification And Regression Trees) is the canonical algorithm. Greedy: at each node, search over all features and all thresholds for the split that most decreases impurity. Recurse on each child. Stop when a depth limit is reached or impurity decrease is too small.
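The greedy search at a single node might look roughly like the sketch below. It is a simplified illustration, not how any particular library implements it: thresholds are taken as midpoints between consecutive distinct values, and best_split is a hypothetical helper name. A full tree would call it recursively on each child until a stopping rule fires.

```python
import numpy as np

def gini(y):
    """Gini impurity of a label array."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Return (feature index, threshold, gain) for the best single split."""
    parent = gini(y)
    best = (None, None, 0.0)
    for j in range(X.shape[1]):
        values = np.unique(X[:, j])
        # Candidate thresholds: midpoints between consecutive distinct values.
        for t in (values[:-1] + values[1:]) / 2:
            mask = X[:, j] <= t
            child = mask.mean() * gini(y[mask]) + (~mask).mean() * gini(y[~mask])
            if parent - child > best[2]:
                best = (j, t, parent - child)
    return best

# Made-up data: feature 0 separates the classes cleanly, feature 1 does not.
X = np.array([[1.0, 5.0], [2.0, 1.0], [3.0, 4.0], [4.0, 2.0]])
y = np.array([0, 0, 1, 1])
feature, threshold, split_gain = best_split(X, y)
print(feature, threshold, split_gain)
```

On this toy data the search picks feature 0 at threshold 2.5 with a gain of 0.5, because that split leaves both children pure.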
Splitting criteria. For classification: Gini (default in sklearn) or entropy (information gain). They're nearly equivalent in practice. For regression: variance reduction — pick the split that most reduces within-node variance of y.
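For the regression case, variance reduction follows the same pattern as gain. A minimal sketch on made-up numbers; variance_reduction is an illustrative helper, not a library function.

```python
import numpy as np

def variance_reduction(y, mask):
    """Parent variance minus the size-weighted variance of the two children."""
    left, right = y[mask], y[~mask]
    weighted = (len(left) * left.var() + len(right) * right.var()) / len(y)
    return y.var() - weighted

# Made-up 1-D data with an obvious jump in y between x = 3 and x = 10.
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([1.0, 1.2, 0.8, 5.0, 5.2, 4.8])

reductions = {t: variance_reduction(y, x <= t) for t in (1.5, 6.5, 10.5)}
for t, r in reductions.items():
    print(f"threshold {t}: variance reduction = {r:.3f}")
```

The threshold that sits in the gap between the two groups removes almost all of the variance, so it is the one a regression tree would choose here.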
Controlling overfitting. Unpruned trees memorise the training set. Three knobs: max_depth (hard limit), min_samples_leaf (refuse splits that produce tiny leaves), min_impurity_decrease (refuse splits below a threshold). Cost-complexity pruning (CCP) is the principled approach: grow a deep tree, then prune back nodes whose contribution to accuracy is too small relative to their complexity.
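To see the three knobs in action, one option is to fit the same noisy data under each setting and compare leaf counts and scores. This is a sketch on synthetic data from make_classification; the specific parameter values are arbitrary illustrations, not recommendations.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data with 20% label noise (illustrative only).
X, y = make_classification(n_samples=600, n_features=10, n_informative=4,
                           flip_y=0.2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

settings = {
    "unpruned": {},
    "max_depth=4": {"max_depth": 4},
    "min_samples_leaf=20": {"min_samples_leaf": 20},
    "min_impurity_decrease=0.01": {"min_impurity_decrease": 0.01},
}

results = {}
for name, params in settings.items():
    clf = DecisionTreeClassifier(random_state=0, **params).fit(X_train, y_train)
    results[name] = (clf.get_n_leaves(),
                     clf.score(X_train, y_train),
                     clf.score(X_test, y_test))
    leaves, train_acc, test_acc = results[name]
    print(f"{name:28s} leaves={leaves:3d} train={train_acc:.2f} test={test_acc:.2f}")
```

The unpruned tree reaches perfect training accuracy by memorising the noise, while each constraint yields a much smaller tree. How the test scores compare will depend on the data, which is why these knobs are usually tuned on a validation set.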
Categorical features. sklearn's trees don't natively handle categoricals; they expect numeric inputs, so use one-hot or target encoding first. LightGBM and XGBoost can handle them natively (via the categorical_feature and enable_categorical options respectively), often using partition-based splits, which are more efficient than one-hot.
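One common way to do the encoding in sklearn is to wrap the tree in a pipeline so the one-hot step travels with the model. A sketch on a tiny made-up table; the column names and values are hypothetical.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

# Tiny made-up dataset, purely for illustration.
df = pd.DataFrame({
    "income": [20, 35, 50, 28, 60, 45],
    "region": ["north", "south", "south", "east", "north", "east"],
})
y = [0, 1, 1, 0, 1, 1]

# One-hot encode the categorical column; pass the numeric column through untouched.
encode = ColumnTransformer(
    [("onehot", OneHotEncoder(handle_unknown="ignore"), ["region"])],
    remainder="passthrough",
)
model = make_pipeline(encode, DecisionTreeClassifier(max_depth=2, random_state=0))
model.fit(df, y)
print(model.predict(df))
```

Setting handle_unknown="ignore" means a category unseen during training is encoded as all zeros at prediction time rather than raising an error.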
Reach for it when
Need an interpretable baseline you can show non-technical stakeholders
Quick exploration on a new dataset (no scaling needed)
Stage 1 of a tree-based pipeline (RF / boosting build on this)
Decisions that should be a literal flowchart
Skip it when
Continuous target where smooth predictions matter
You'll use an ensemble anyway — go straight to that
High-dim sparse text — linear models or NB win
Strong distributional assumptions you'd rather encode explicitly
In words. Instead of judging a tree only by how often it gets training points wrong, also charge it for how many leaves it has. α (alpha, a small positive number) is the price-per-leaf — at α = 0 you keep the whole tree, and as α grows you progressively prune branches whose accuracy gain doesn't justify their leaf cost. Sweeping α from 0 upwards traces out a nested sequence of pruned subtrees; pick whichever one has the best validation score (often the smallest within one standard error of the best, as in the code below).
penalised cost: the quantity being minimised (fit plus complexity)
misclassification cost: fraction of training points the tree gets wrong
number of leaves: count of terminal nodes, a proxy for tree size
α: regularization strength (bigger α means a harsher penalty per leaf)
Greedy splits are myopic. CART picks the best local split at each node, which isn't always the globally optimal tree. Finding the optimal tree is NP-hard. Modern interpretable-ML methods (OCT, MurTree) use mathematical programming to find globally optimal trees for small problems — useful for high-stakes applications where you want to defend the structure.
Bias and feature importance. Impurity-based importance (sklearn's feature_importances_) is biased toward high-cardinality features — they offer more split candidates. Permutation importance on a held-out set gives more honest estimates. For random forests this matters even more — see the discussion on the RF page.
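The bias is easy to provoke on synthetic data. The sketch below is an assumption-laden toy: one genuinely predictive binary feature, one unique-per-row noise column (standing in for a high-cardinality ID), and labels with about 20% flipped.

```python
import numpy as np
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n = 1000
signal = rng.integers(0, 2, size=n)            # binary and genuinely predictive
noise_id = rng.permutation(n).astype(float)    # unique per row, pure noise
X = np.column_stack([signal, noise_id])
y = np.where(rng.random(n) < 0.8, signal, 1 - signal)  # label = signal, ~20% flipped

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Permutation importance is measured on held-out data, not the training set.
perm = permutation_importance(clf, X_test, y_test, n_repeats=20, random_state=0)
print("impurity-based [signal, noise_id]:", clf.feature_importances_)
print("permutation    [signal, noise_id]:", perm.importances_mean)
```

The unpruned tree uses the noise column to memorise the flipped labels, so impurity-based importance credits it heavily. Shuffling that column on held-out data barely changes accuracy, which is what permutation importance reports.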
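Missing values. CART originally used surrogate splits: backup splits on other features that stand in when the primary split's feature is missing. sklearn doesn't implement surrogates; LightGBM and XGBoost instead learn, at each split, which child missing values should be routed to, choosing whichever improves the loss most. In sklearn, the simplest options are to impute first or to use HistGradientBoostingClassifier, which handles missing values natively; from version 1.3, DecisionTreeClassifier itself also accepts NaN inputs.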
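Monotonicity constraints. Sometimes you need the model to be monotone in a feature (e.g. predicted credit risk should not increase with income). XGBoost and LightGBM both support monotone constraints; classic CART doesn't. sklearn exposes a monotonic_cst parameter on its histogram gradient boosting estimators and, from version 1.4, on its tree estimators too. When this matters, check that your framework and version support it.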
Where trees go to die. Decision trees alone are rarely state-of-the-art today. They're foundations for ensembles (RF, GBM). The interesting research is now in: differentiable trees (soft decision trees, tree-MLPs), neural-tree hybrids (NODE, TabNet), and globally-optimal trees for interpretability-critical settings.
Reach for it when
Regulatory / interpretability requirements that need a literal flowchart
Cost-complexity pruning gives you a principled accuracy/complexity trade-off
You're building or debugging an ensemble and need to inspect base learners
Globally-optimal trees on small problems — for defensibility
Skip it when
You're chasing accuracy — ensembles dominate
Smooth function approximation needed
Adversarial robustness — trees have known attack surfaces
Online / streaming setting — incremental tree algorithms exist but are niche
from sklearn.tree import DecisionTreeClassifier
import numpy as np

# Assumes X_train, y_train, X_val, y_val already exist
# Cost-complexity pruning path: the alphas at which the pruned subtree changes
tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
path = tree.cost_complexity_pruning_path(X_train, y_train)
# Clip guards against tiny negative alphas caused by floating-point error
alphas = np.clip(path.ccp_alphas, 0.0, None)

# Train one tree per alpha (alphas are sorted ascending, so trees get smaller)
trees = [
    DecisionTreeClassifier(ccp_alpha=a, random_state=0).fit(X_train, y_train)
    for a in alphas
]
val_scores = np.array([t.score(X_val, y_val) for t in trees])
best_i = val_scores.argmax()

# Standard error of the best validation accuracy (binomial approximation)
p = val_scores[best_i]
se = np.sqrt(p * (1 - p) / len(y_val))

# Largest alpha, i.e. the simplest tree, whose score is within 1 SE of the best
within = np.where(val_scores >= p - se)[0]
chosen = trees[within[-1]]
print(f"Chose tree with {chosen.get_n_leaves()} leaves, "
      f"val acc = {val_scores[within[-1]]:.3f}")