11.4 Precision-Recall Curves for Imbalanced Datasets
Right, let’s talk about the one metric to rule them all for imbalanced datasets. You’ve probably been told that accuracy is a dirty liar in these situations, and you were told correctly. If I have a dataset where 99% of transactions are not fraudulent, my idiot model can achieve 99% accuracy by just yelling “NOT FRAUD!” every single time. It’s technically correct, but utterly useless. We need a more nuanced way to judge performance, and that’s where the precision-recall curve comes in. It’s the trusty sidekick you need when your classes are wildly out of balance.
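To make the accuracy trap concrete, here is a minimal sketch using made-up labels: 10,000 hypothetical transactions, 1% of them fraudulent, and a "model" that always answers not-fraud. The array names and counts are illustrative assumptions, not data from a real system.

```python
import numpy as np

# Hypothetical labels: 10,000 transactions, 100 of them fraudulent (1%).
y_true = np.zeros(10_000, dtype=int)
y_true[:100] = 1

# The "idiot model": predict NOT FRAUD (0) for everything.
y_pred = np.zeros_like(y_true)

accuracy = (y_pred == y_true).mean()
# Recall: fraction of the actual fraud cases that were flagged.
recall = ((y_pred == 1) & (y_true == 1)).sum() / (y_true == 1).sum()

print(f"Accuracy: {accuracy:.2%}")  # Accuracy: 99.00%
print(f"Recall:   {recall:.2%}")    # Recall:   0.00%
```

Ninety-nine percent accuracy, zero fraud caught. That gap is the whole reason this section exists.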
Think of it this way: Precision and Recall are two bickering but brilliant detectives on a case.
- Precision: “When I say it’s fraud, how often am I actually right?” This is about confidence. High precision means you don’t cry wolf. Your positive predictions are trustworthy.
- Recall: “Out of all the actual fraud, how much did I actually catch?” This is about completeness. High recall means you’re not letting bad guys slip through the net.
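In terms of confusion-matrix counts, precision is TP / (TP + FP) and recall is TP / (TP + FN). A small sketch with hypothetical counts shows the arithmetic; the helper name `precision_recall` and the zero-denominator convention are assumptions made for this example.

```python
def precision_recall(tp, fp, fn):
    """Return (precision, recall) from true-positive, false-positive and false-negative counts."""
    # Returning 0.0 for an empty denominator is a convention chosen for this sketch.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return precision, recall

# Hypothetical counts: 100 transactions flagged, 80 of them really fraud, 120 fraud cases missed.
precision, recall = precision_recall(tp=80, fp=20, fn=120)
print(f"precision={precision:.2f}, recall={recall:.2f}")  # precision=0.80, recall=0.40
```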
The eternal tug-of-war between these two is the heart of the matter. You can have high precision (only predicting the most blatant fraud cases) but low recall (missing a ton of subtle fraud). Or you can have high recall (flagging everything remotely suspicious) but abysmal precision (and now you’ve got a thousand false alarms and a very angry customer support team). The precision-recall curve beautifully visualizes this trade-off for every possible decision threshold your model can use.
The Anatomy of a PR Curve
Let’s build this thing from the ground up. Your model (say, a logistic regression or an SVM) doesn’t just spit out a final class label; it outputs a probability or a score for each sample. For probabilities, the default threshold is usually 0.5. “Score >= 0.5? That’s a 1. Score < 0.5? That’s a 0.”
But who died and made 0.5 king? Nobody. That’s the first thing you need to unlearn. The PR curve is created by sweeping this threshold from 0 to 1 and calculating the precision and recall at each possible threshold value. You then plot precision (y-axis) against recall (x-axis). At a threshold near 0, everything gets flagged, so recall is 1 and precision is just the fraction of positives in the data: that’s the bottom-right end of the curve. As the threshold increases, you make fewer, but more confident, positive predictions, so recall falls and precision generally rises, tracing the curve back toward the top-left (low recall, high precision).
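The threshold sweep can be sketched by hand with NumPy. The ten labels and scores below are invented for illustration, and the `sweep` helper is a hypothetical name, not a library function. It uses each distinct score as a threshold, which is enough to hit every point where the predictions change.

```python
import numpy as np

# Toy data (hypothetical): 3 positives among 10 samples, with model scores.
y_true = np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
y_scores = np.array([0.05, 0.10, 0.20, 0.30, 0.35, 0.40, 0.55, 0.60, 0.70, 0.90])

def sweep(y_true, y_scores):
    """Yield (threshold, precision, recall) for each distinct score used as a threshold."""
    n_pos = y_true.sum()
    for t in np.unique(y_scores):           # thresholds in ascending order
        y_pred = y_scores >= t
        tp = np.sum(y_pred & (y_true == 1))
        precision = tp / y_pred.sum()       # y_pred.sum() >= 1 because t is one of the scores
        recall = tp / n_pos
        yield float(t), float(precision), float(recall)

points = list(sweep(y_true, y_scores))
for t, p, r in points:
    print(f"threshold={t:.2f}  precision={p:.2f}  recall={r:.2f}")
```

At the lowest threshold every sample is flagged, so recall is 1.00 and precision is 0.30, the positive fraction. At the highest threshold only the top-scoring sample is flagged, giving precision 1.00 and recall 0.33.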
The “ideal” point is the top-right corner: perfect precision and perfect recall. The closer your curve hugs that corner, the better your model is. A curve that languishes near the baseline, a horizontal line at a precision equal to the fraction of positives in your data, is a model that’s basically guessing.
Here’s how you generate one using scikit-learn. We’ll use a classic imbalanced dataset.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
# Let's cook up a suitably imbalanced dataset: 1000 samples, roughly 10% of them in the positive class.
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2,
                           n_redundant=10, n_clusters_per_class=1,
                           weights=[0.9, 0.1], random_state=42)
# Split it. stratify=y is crucial here to preserve the imbalance in both sets.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
# Train a simple model. No point in overcomplicating this for demonstration.
model = LogisticRegression(random_state=42, max_iter=1000)
model.fit(X_train, y_train)
# Get the predicted probabilities for the positive class (class 1)
y_scores = model.predict_proba(X_test)[:, 1]
# This is the magic function. It returns all the precision/recall values and their corresponding thresholds.
precision, recall, thresholds = precision_recall_curve(y_test, y_scores)
# Let's plot it.
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, marker='.', label='Our Logistic Regression')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.grid(True)
# Let's also find the Average Precision (AP) score - a single number summary of the curve.
ap_score = average_precision_score(y_test, y_scores)
print(f"Average Precision Score: {ap_score:.3f}")
plt.legend()
plt.show()
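One detail of precision_recall_curve that trips people up: the precision and recall arrays are one element longer than thresholds. The extra trailing point (precision 1, recall 0) has no threshold attached. A tiny sketch with four made-up samples shows how to line the arrays up.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Four hypothetical samples, just enough to inspect the returned arrays.
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# precision and recall carry one extra trailing point (precision=1, recall=0)
# that has no corresponding threshold.
print(len(precision), len(recall), len(thresholds))
print(precision[-1], recall[-1])

# Pair each threshold with its operating point by dropping that last entry.
for t, p, r in zip(thresholds, precision[:-1], recall[:-1]):
    print(f"threshold={t:.2f}  precision={p:.2f}  recall={r:.2f}")
```

Keep this offset in mind whenever you index thresholds with a position found in precision or recall.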
Interpreting the Curve and Choosing a Threshold
Looking at the curve is one thing; using it to make a decision is another. The thresholds array returned by precision_recall_curve is your key to operationalizing this. Let’s say your business problem is this: “I need to catch at least 80% of all fraud (Recall >= 0.8), and I want the highest possible precision at that level.”
You can find the threshold that gets you as close to that recall as possible.
import numpy as np
# Find the last operating point whose recall is still at or above the target.
# (Thresholds increase along the arrays, so recall never goes up.)
target_recall = 0.8
# recall has one more entry than thresholds (the final recall=0 point), so drop it.
# The curve starts at recall 1.0, so there is always at least one candidate.
candidates = np.where(recall[:-1] >= target_recall)[0]
index = candidates[-1]  # highest threshold that still meets the recall target
recommended_threshold = thresholds[index]
precision_at_target = precision[index]
recall_at_target = recall[index]
print(f"To achieve ~{recall_at_target:.2f} recall, use a threshold of {recommended_threshold:.3f}")
print(f"This will give you a precision of ~{precision_at_target:.2f}")
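Once you have a threshold, you have to apply it yourself, because a logistic regression's predict() sticks with the 0.5 cut-off. The sketch below shows one way to do that. The labels and scores are synthetic stand-ins (an assumption for this example) so the block runs on its own; in practice you would use your held-out labels and predict_proba scores.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, precision_score, recall_score

rng = np.random.default_rng(0)
# Hypothetical held-out labels (~10% positive) and scores standing in for real model output.
y_true = (rng.random(2000) < 0.1).astype(int)
y_scores = np.clip(0.25 * y_true + rng.normal(0.3, 0.15, size=2000), 0, 1)

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
index = np.where(recall[:-1] >= 0.8)[0][-1]  # last point still meeting the recall target
chosen_threshold = thresholds[index]

# Apply the threshold directly to the scores instead of calling predict().
y_pred = (y_scores >= chosen_threshold).astype(int)

print(f"threshold={chosen_threshold:.3f}")
print(f"precision={precision_score(y_true, y_pred):.3f}")
print(f"recall={recall_score(y_true, y_pred):.3f}")
```

The precision and recall you measure this way should match the point you picked on the curve, which is a handy sanity check before shipping a custom threshold.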
This gives you a concrete, data-driven reason to move the threshold away from the arbitrary 0.5 to a value that actually suits your problem’s needs.
The Pitfalls and The “No Free Lunch” Reality
First, the big one: The Baseline Matters. On a severely imbalanced dataset, the “baseline” isn’t 0.5. The baseline is the ratio of positive examples in your data, because that is roughly the precision a random guesser gets at every recall level. If only 1% of your data is positive, a model that achieves an Average Precision (AP) well above 0.01 is actually doing something. Always compare your AP score to the fraction of positives. If your AP is 0.02 on a 1% positive set, you’ve only marginally improved over random guessing.
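You can check the baseline claim empirically. This sketch scores a hypothetical 1%-positive dataset with pure noise and compares the resulting AP with the positive fraction. The sample size and variable names are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.metrics import average_precision_score

rng = np.random.default_rng(0)
n_samples, positive_rate = 100_000, 0.01

# Hypothetical labels with ~1% positives, and scores that carry no information.
y_true = (rng.random(n_samples) < positive_rate).astype(int)
random_scores = rng.random(n_samples)

baseline = y_true.mean()
ap_random = average_precision_score(y_true, random_scores)

print(f"Positive fraction (baseline): {baseline:.4f}")
print(f"AP of random scores:          {ap_random:.4f}")
print(f"Lift over baseline:           {ap_random / baseline:.2f}x")
```

Reporting AP as a multiple of the positive fraction, as in the last line, makes scores comparable across datasets with different levels of imbalance.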
Second, beware of wiggly, unstable curves. Some sawtooth is normal: every false positive that crosses the threshold knocks precision down, and every true positive nudges it back up. But when the jumps are large, it usually means you don’t have enough positive examples in your evaluation set, so each one moves the curve a lot. Models that output only a handful of distinct scores, such as a single decision tree, produce the same blocky look. Either way, treat any single point on a jagged curve with suspicion, because it may not hold up on fresh data.
Finally, remember that optimizing for the PR curve is still a simplification. It doesn’t account for the actual financial cost of a false positive versus a false negative. That final decision of where to operate on the curve—whether to favor precision or recall—isn’t a statistical question. It’s a business one. Your job is to hand the product manager this curve and say, “Here are your options. You tell me which trade-off you can live with.”
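If the business can put rough numbers on the two kinds of mistake, you can turn that conversation into a calculation. The sketch below picks the threshold that minimizes total cost. The cost figures, the synthetic scores, and the `total_cost` helper are all hypothetical; real costs have to come from whoever owns the problem.

```python
import numpy as np

# Hypothetical costs in the same currency unit; real values must come from the business.
COST_FALSE_POSITIVE = 5.0     # e.g. manual review of a legitimate transaction
COST_FALSE_NEGATIVE = 200.0   # e.g. a fraudulent transaction that slipped through

def total_cost(y_true, y_scores, threshold):
    """Total cost of the errors made when flagging every score >= threshold."""
    y_pred = y_scores >= threshold
    fp = np.sum(y_pred & (y_true == 0))
    fn = np.sum(~y_pred & (y_true == 1))
    return fp * COST_FALSE_POSITIVE + fn * COST_FALSE_NEGATIVE

rng = np.random.default_rng(0)
# Synthetic stand-ins for held-out labels (~5% positive) and model scores.
y_true = (rng.random(5000) < 0.05).astype(int)
y_scores = np.clip(0.3 * y_true + rng.normal(0.3, 0.15, size=5000), 0, 1)

candidate_thresholds = np.linspace(0.0, 1.0, 101)
costs = np.array([total_cost(y_true, y_scores, t) for t in candidate_thresholds])
best_threshold = candidate_thresholds[np.argmin(costs)]

print(f"Cheapest threshold: {best_threshold:.2f} (total cost {costs.min():.0f})")
print(f"Cost at 0.50:       {total_cost(y_true, y_scores, 0.5):.0f}")
```

The answer is only as good as the cost estimates, so treat it as a starting point for the discussion rather than a verdict.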