6.2 Pruning: Pre-Pruning and Post-Pruning
Right, so you’ve built a decision tree. It’s a thing of beauty. It fits your training data perfectly. You run it on your test set and… oh. It’s a disaster. It’s memorized every single quirk and bit of noise in your training data, including the ID of the customer who bought the product and what they had for lunch. This is the textbook definition of overfitting, and it’s why a full-grown, un-pruned tree is often about as useful as a chocolate teapot.
We fix this with a technique called pruning. The metaphor is perfect: you’re cutting away the unnecessary branches (the complex, over-specific rules) to leave a stronger, more robust tree. There are two main strategies for this: stopping the growth early (pre-pruning) and growing it fully then cutting it back (post-pruning). One is like a cautious gardener, the other like a bold topiarist.
Pre-Pruning: The Anxious Gardener
Pre-pruning, also called “early stopping”, involves setting hyperparameters that halt the tree’s growth before every leaf becomes perfectly pure. You’re essentially putting rules in place that say, “Okay, that’s deep enough,” or “That split isn’t worth it, stop there.”
The main levers you can pull in scikit-learn are:
max_depth: The absolute simplest one. “Don’t grow deeper than this many questions.”
min_samples_split: A node must have at least this many samples before it’s even allowed to be split.
min_samples_leaf: A leaf must have at least this many samples. This prevents a split that would create a leaf with, say, 2 samples in it, which is a classic overfitting move.
min_impurity_decrease: A split must reduce the weighted impurity (Gini or entropy) by at least this amount. If the potential gain is tiny, it’s not worth the complexity.
Here’s how you’d use them. Notice we’re not using the iris dataset for the millionth time; we’ll use a slightly more interesting one.
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Make a synthetic dataset that has a non-linear structure
X, y = make_moons(n_samples=300, noise=0.25, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
# An overfit tree for comparison
overfit_tree = DecisionTreeClassifier(random_state=42)
overfit_tree.fit(X_train, y_train)
print(f"Overfit Tree Train Score: {overfit_tree.score(X_train, y_train):.3f}")
print(f"Overfit Tree Test Score: {overfit_tree.score(X_test, y_test):.3f}")
# A pre-pruned tree
pruned_tree = DecisionTreeClassifier(
    max_depth=4,
    min_samples_split=10,
    min_samples_leaf=5,
    max_features='sqrt',  # Not a growth limit: each split sees a random feature subset (1 of 2 here)
    random_state=42
)
pruned_tree.fit(X_train, y_train)
print(f"\nPruned Tree Train Score: {pruned_tree.score(X_train, y_train):.3f}")
print(f"Pruned Tree Test Score: {pruned_tree.score(X_test, y_test):.3f}")
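Scores alone don’t tell you whether the limits actually took hold, so it can help to look at the shape of the fitted trees. The following is a minimal sketch, not part of the original example: the helper name describe is made up for illustration, and it reads the low-level tree_ attribute, where a left-child index of -1 marks a leaf. It also tries min_impurity_decrease, which the listing above doesn’t use; the 0.01 threshold is an arbitrary illustrative value.

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data as in the example above
X, y = make_moons(n_samples=300, noise=0.25, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)


def describe(tree):
    """Return (depth, number of leaves, smallest leaf size) for a fitted tree.

    Hypothetical helper for illustration only.
    """
    is_leaf = tree.tree_.children_left == -1  # -1 marks a leaf node in tree_
    smallest_leaf = int(tree.tree_.n_node_samples[is_leaf].min())
    return tree.get_depth(), tree.get_n_leaves(), smallest_leaf


full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
capped = DecisionTreeClassifier(
    max_depth=4, min_samples_leaf=5, random_state=42
).fit(X_train, y_train)
# Illustrative threshold, not a recommended value
gated = DecisionTreeClassifier(
    min_impurity_decrease=0.01, random_state=42
).fit(X_train, y_train)

for name, tree in [("full", full), ("capped", capped), ("gated", gated)]:
    depth, leaves, smallest = describe(tree)
    print(f"{name:>6}: depth={depth}, leaves={leaves}, smallest leaf={smallest}")
```

The capped tree can never exceed depth 4 (so at most 16 leaves) and never has a leaf with fewer than 5 samples; the other two trees are limited only by the data and, for the gated one, by the impurity threshold.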
The pitfall with pre-pruning? It’s called the horizon effect. An early, seemingly weak split might lead to a fantastic split later on. By stopping early, you might never see it. It’s like refusing to turn down a street because it looks a bit shabby, not realizing it’s the only route to the highway. Pre-pruning is fast and simple, but it can be myopic.
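The classic way to see the horizon effect is an XOR pattern, where neither feature helps on its own but the two together separate the classes perfectly. Here is a small sketch under that assumption: the data is a hand-built toy (four corner points repeated), and the 0.01 threshold for min_impurity_decrease is just an illustrative value.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Exact XOR: the label is 1 only when the two features differ
corners = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
X_xor = np.tile(corners, (25, 1))  # 100 samples, perfectly balanced
y_xor = np.logical_xor(X_xor[:, 0], X_xor[:, 1]).astype(int)

# No limits: the first split gains nothing, but the second one fixes everything
unconstrained = DecisionTreeClassifier(random_state=0).fit(X_xor, y_xor)

# Pre-pruned: the zero-gain first split is rejected, so the tree never grows
stopped = DecisionTreeClassifier(
    min_impurity_decrease=0.01, random_state=0
).fit(X_xor, y_xor)

print("unconstrained depth:", unconstrained.get_depth(),
      "accuracy:", unconstrained.score(X_xor, y_xor))
print("stopped depth:", stopped.get_depth(),
      "accuracy:", stopped.score(X_xor, y_xor))
```

Any first split leaves both children at a 50/50 class mix, so its impurity decrease is zero. The gated tree therefore stays a single node, while the unconstrained tree accepts the "useless" split and separates the classes one level down.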
Post-Pruning: The Bold Topiarist
This is the more sophisticated approach. Here’s the plan: you let the tree grow. Let it get big, ugly, and deeply, deeply overfit. Then, you go back and surgically remove the branches that provide the least predictive power. You’re effectively replacing complex sub-trees with simple leaf nodes.
The most common method is Cost Complexity Pruning (aka Weakest Link Pruning). It introduces a hyperparameter alpha (≥0) that governs a trade-off between the tree’s complexity and its fit to the data. A higher alpha penalizes complexity more heavily, resulting in a simpler tree.
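For reference, the trade-off that alpha controls is usually written as a penalized objective. The notation below follows the standard cost-complexity formulation, as used in scikit-learn’s documentation; the symbols are not from the text above and are introduced here only for illustration. R(T) is the total sample-weighted impurity of the leaves of tree T, |T̃| is its number of leaves, T_t is the subtree rooted at node t, and R(t) is the impurity node t would have if it were collapsed into a leaf.

```latex
R_\alpha(T) = R(T) + \alpha \, |\widetilde{T}|
\qquad\qquad
\alpha_{\text{eff}}(t) = \frac{R(t) - R(T_t)}{|\widetilde{T}_t| - 1}
```

The second expression is the “effective alpha” of a node: the penalty level at which keeping its subtree stops paying for itself. The node with the smallest effective alpha is the weakest link and gets pruned first, which is where the method’s other name comes from.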
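The nice thing about scikit-learn is that it generates the candidate alphas for you. DecisionTreeClassifier has a method, cost_complexity_pruning_path, which fits a tree on the data you pass in and returns the effective alphas at which successive subtrees get pruned away, along with the total leaf impurity of the tree at each step. Picking the best candidate is still up to you.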
# Let the tree grow wild
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)
print(f"Full-grown tree depth: {tree.get_depth()}")
# Get the pruning path
path = tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
# Train one tree per effective alpha, skipping the last one (it prunes the tree down to just the root)
pruned_trees = []
for ccp_alpha in ccp_alphas[:-1]:
    pruned_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha)
    pruned_tree.fit(X_train, y_train)
    pruned_trees.append(pruned_tree)
# Let's see what happens to our test score
test_scores = [t.score(X_test, y_test) for t in pruned_trees]
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], test_scores, marker="o", drawstyle="steps-post")
ax.set_xlabel("Effective alpha")
ax.set_ylabel("Test accuracy")
ax.set_title("Accuracy vs. Alpha for Post-Pruned Trees")
plt.show()
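You’ll typically see test accuracy improve as you introduce a little pruning (for a dataset like this one, often somewhere around alpha 0.01 to 0.02, though the exact range depends on the data and the noise level), then decline as you become overzealous and chop the tree into a stump. The sweet spot is the peak of the curve. One caveat: if you pick alpha by reading the peak off a test-set curve, you have tuned the model on the test set, and the accuracy you report will be optimistic. Choose alpha on a validation split or with cross-validation, and keep the test set for a final check.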
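Choosing alpha from the test-score curve is fine for building intuition, but for a real model you would normally select it without touching the test set. One way to do that, sketched here as an assumption rather than the only option, is to feed the candidate alphas from the pruning path into GridSearchCV with 5-fold cross-validation. The clipping step is a defensive guess on my part: floating-point error can occasionally push a computed alpha slightly below zero, and ccp_alpha must be non-negative.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data and split as in the examples above
X, y = make_moons(n_samples=300, noise=0.25, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Candidate alphas come from the pruning path of the full tree
base = DecisionTreeClassifier(random_state=42)
path = base.cost_complexity_pruning_path(X_train, y_train)

# Drop the last alpha (root-only tree), guard against tiny negative values,
# and remove duplicates
candidate_alphas = np.unique(np.clip(path.ccp_alphas[:-1], 0.0, None))

# Cross-validate on the training data only; the test set stays untouched
search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={"ccp_alpha": candidate_alphas},
    cv=5,
)
search.fit(X_train, y_train)

best_alpha = search.best_params_["ccp_alpha"]
test_accuracy = search.best_estimator_.score(X_test, y_test)

print(f"Best alpha by 5-fold CV: {best_alpha:.4f}")
print(f"Leaves in selected tree: {search.best_estimator_.get_n_leaves()}")
print(f"Held-out test accuracy:  {test_accuracy:.3f}")
```

Because the test set played no part in the choice, the final score is a fairer estimate of how the pruned tree will behave on new data.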
So, Which One Should You Use?
In practice, post-pruning is often the better choice. It is far less exposed to the horizon effect because it gets to see the whole tree before making cuts, and the cuts are driven by measured impurity rather than by limits you guessed up front. The downside? It’s computationally more expensive: you build a full tree first, and tuning alpha usually means fitting one tree per candidate value.
My advice? Use post-pruning (ccp_alpha) for your final, serious models. Use pre-pruning parameters (max_depth, min_samples_leaf) for quick initial experiments and for putting a hard ceiling on tree size, which is useful in ensemble methods like Random Forests, where you’re building hundreds of these things and want to bound memory and training time. The best model often comes from combining the two: a generous max_depth, loose enough that the horizon effect rarely bites, followed by a tuned ccp_alpha to do the precise sculpting.
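To make the combination concrete, here is a minimal sketch that sets both knobs on the same estimator. The depth cap of 8 and the alpha of 0.01 are placeholder values chosen for illustration, not tuned results; in a real project you would pick alpha by cross-validation.

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data and split as in the examples above
X, y = make_moons(n_samples=300, noise=0.25, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

GENEROUS_DEPTH = 8         # illustrative hard ceiling on tree depth
ILLUSTRATIVE_ALPHA = 0.01  # placeholder; tune this rather than hard-coding it

# Depth cap only
capped = DecisionTreeClassifier(
    max_depth=GENEROUS_DEPTH, random_state=42
).fit(X_train, y_train)

# Depth cap plus cost-complexity pruning applied to the capped tree
combined = DecisionTreeClassifier(
    max_depth=GENEROUS_DEPTH, ccp_alpha=ILLUSTRATIVE_ALPHA, random_state=42
).fit(X_train, y_train)

for name, tree in [("capped", capped), ("combined", combined)]:
    print(f"{name:>8}: depth={tree.get_depth()}, leaves={tree.get_n_leaves()}, "
          f"test accuracy={tree.score(X_test, y_test):.3f}")
```

The depth cap bounds how big the tree can ever get, and ccp_alpha then removes whatever branches within that bound don’t earn their keep, so the combined tree is never larger than the capped one.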