13.3 Optuna: Define-by-Run API and Pruning
Right, so you’ve decided to stop just randomly poking at numbers like learning_rate=1e-4 and hoping for the best. Good. Welcome to the grown-up table. We’re going to talk about Optuna, which is, frankly, one of the best things to happen to hyperparameter optimization. Its “define-by-run” API is the key reason why. It feels like you’re just writing a script, not filling out a government form in triplicate.
Other libraries (I’m looking at you, Hyperopt) make you define your search space statically before your trial logic. It’s clunky. Optuna’s define-by-run approach lets you dynamically define parameters right where you need them, inside your objective function. This is wildly more flexible. Need to conditionally suggest a parameter based on another? Go for it. Want to structure your code logically? Nothing’s stopping you.
Here’s the absolute simplest way it works. You define an objective function that takes an Optuna trial object, and you use that object to ask for parameter suggestions.
```python
import optuna

def objective(trial):
    # Ask Optuna to suggest values for the hyperparameters.
    n_layers = trial.suggest_int('n_layers', 1, 5)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)

    # ... Build your model using these values ...
    # (create_model and train_and_evaluate are placeholders for your own code.)
    model = create_model(n_layers, learning_rate)

    # ... Train your model and evaluate it ...
    accuracy = train_and_evaluate(model)

    # Return the metric you want to minimize or maximize.
    return accuracy

# Create a study and run the optimization.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

print(f"Best value: {study.best_value}")
print(f"Best params: {study.best_params}")
```
See? No bizarre DSL, no separate space definition. You’re just writing Python. The trial.suggest_* methods are how you tell Optuna what to try. suggest_categorical for discrete choices, suggest_int for integers, and suggest_float for… well, you get it. The log=True argument is a lifesaver for things like learning rates where you care about orders of magnitude, not linear steps.
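To see why log=True matters, here is a small standard-library sketch (no Optuna involved) that draws learning rates from [1e-5, 1e-2] two ways: uniformly on a linear scale, and uniformly in log space, which is the scale log=True tells the sampler to work in. It then counts how often each draw lands below 1e-3.

```python
import math
import random

LOW, HIGH, CUTOFF = 1e-5, 1e-2, 1e-3
N = 100_000
rng = random.Random(42)

def linear_sample():
    # Every interval of equal *width* is equally likely.
    return rng.uniform(LOW, HIGH)

def log_sample():
    # Every interval of equal *ratio* (e.g. one decade) is equally likely.
    return math.exp(rng.uniform(math.log(LOW), math.log(HIGH)))

linear_frac = sum(linear_sample() < CUTOFF for _ in range(N)) / N
log_frac = sum(log_sample() < CUTOFF for _ in range(N)) / N

# Analytically: linear = (1e-3 - 1e-5) / (1e-2 - 1e-5), about 0.099;
# log = 2 decades out of 3, about 0.667.
print(f"linear scale: {linear_frac:.3f} of samples below 1e-3")
print(f"log scale:    {log_frac:.3f} of samples below 1e-3")
```

On the linear scale roughly nine draws out of ten land in the top decade, [1e-3, 1e-2], so the small learning rates you probably care about barely get tried. In log space each decade gets an equal share.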
The Magic of Pruning: Not Every Child is a Winner
Here’s the concept that’ll save you a small fortune in cloud compute bills: pruning. You know that feeling when you’re three epochs into a training run and the loss is already heading to infinity? Pruning is the process of automatically telling that trial to just give up. Stop wasting GPU cycles. It’s a mercy killing.
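Optuna’s pruners compare each running trial against the intermediate values that other trials reported at the same step. The default, MedianPruner, applies a simple median-stopping rule; Optuna also ships other strategies, such as SuccessiveHalvingPruner (asynchronous successive halving) and HyperbandPruner. Whichever you pick, a trial that’s performing poorly at an early step (like a low intermediate accuracy after a few epochs) gets killed off, freeing resources for more promising candidates.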
To use it, you have to “report” your intermediate scores back to Optuna and ask it if you should quit.
```python
import optuna

def objective_with_pruning(trial):
    n_layers = trial.suggest_int('n_layers', 1, 5)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    model = create_model(n_layers, learning_rate)

    # Let's say we train for 10 epochs, but we want to check for pruning after each.
    for epoch in range(10):
        # Your function: trains for one epoch and returns validation accuracy.
        accuracy = train_epoch(model)

        # Report the current score for this step (the epoch index).
        trial.report(accuracy, epoch)

        # Ask the pruner whether this trial looks hopeless.
        if trial.should_prune():
            # Raising TrialPruned tells Optuna to record the trial as pruned and move on.
            raise optuna.exceptions.TrialPruned()

    return accuracy

# create_study already defaults to MedianPruner, but passing a pruner explicitly
# makes the choice visible and easy to tune.
study = optuna.create_study(
    direction='maximize',
    pruner=optuna.pruners.MedianPruner()  # A good default pruner
)
study.optimize(objective_with_pruning, n_trials=50)
```
The MedianPruner is a solid starting point. It basically says, “if your trial’s best intermediate value so far is worse than the median of what previous trials reported at the same step, you’re probably a loser, and you’re gone.” It’s brutally effective.
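To make that rule concrete, here is a simplified pure-Python sketch of the median-stopping check for a metric you maximize. It is not Optuna’s source: should_prune_median and its arguments are hypothetical names, and the real MedianPruner handles more details (minimization, interval_steps, NaN values). It does mirror the startup and warmup knobs and the “best value so far versus the median at this step” comparison.

```python
from statistics import median

def should_prune_median(my_values, completed_curves, n_startup_trials=5, n_warmup_steps=0):
    """Simplified median-stopping rule for a metric we MAXIMIZE.

    my_values: intermediate values this trial has reported so far (index = step).
    completed_curves: one list of intermediate values per finished earlier trial.
    """
    if len(completed_curves) < n_startup_trials:
        return False  # not enough finished trials to compare against yet
    step = len(my_values) - 1
    if step < n_warmup_steps:
        return False  # give every trial a few steps before judging it
    peers = [curve[step] for curve in completed_curves if len(curve) > step]
    if not peers:
        return False
    # Compare this trial's best value so far with the median of its peers at this step.
    return max(my_values) < median(peers)

history = [
    [0.50, 0.70, 0.80],
    [0.40, 0.60, 0.75],
    [0.55, 0.72, 0.85],
]
print(should_prune_median([0.30, 0.45], history, n_startup_trials=3))  # True (median at step 1 is 0.70)
print(should_prune_median([0.60, 0.71], history, n_startup_trials=3))  # False
print(should_prune_median([0.30], history, n_startup_trials=3, n_warmup_steps=1))  # False (still warming up)
```

The warmup check is what protects slow starters: a trial that begins badly but climbs fast never gets judged on its first few steps.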
Common Pitfalls and The “Oh, Right” Moments
Stateful Chaos: Your objective function must be stateless. If it modifies a global variable, you will have a spectacularly bad time. Optuna can run trials concurrently (with n_jobs > 1, or with several workers sharing one storage), so each trial must be independent. If you need to set a random seed, derive it from trial.number to seed your RNGs for reproducibility.

Pruner Tuning: The default pruner might be too aggressive or too lax for your problem. If it’s too aggressive, it might kill a slow starter that would eventually have become the champion. If it’s too lax, you’re not saving much time. You often need to adjust MedianPruner’s n_startup_trials (how many trials must finish before any pruning happens) and n_warmup_steps (how many steps a trial runs before it can be pruned). Don’t just accept the defaults; think about your training curve.

The Categorical Trap: suggest_categorical('name', ['option1', 'option2']) is your best and worst friend. It’s incredibly flexible, but if you give it too many options, especially numeric values that are really points on a continuous scale, you’re forcing Optuna to treat them as unrelated choices. If you have a list of learning rates [1e-5, 1e-4, 1e-3, 1e-2], you’re better off using suggest_float(..., log=True). The sampler can then explore the logarithmic space intelligently instead of just guessing among four points.

Database Persistence: For any serious work, please use database storage (storage='sqlite:///my_study.db'). Combined with a fixed study_name and load_if_exists=True, it lets you stop and restart your script without losing your progress. It also lets you run multiple scripts adding trials to the same study, which is a great way to parallelize on a cluster (for many concurrent workers, a server database such as PostgreSQL or MySQL copes better than SQLite). The in-memory default is for toys and quick tests. I’ve lost a weekend’s worth of trials to a laptop reboot once. Learn from my pain.
The beauty of Optuna is that it gets out of your way. You write your training loop, drop in a few trial.suggest_* calls and a pruning check, and you’re off to the races. It feels less like configuring a framework and more like just getting help from a very smart, very impatient friend who’s constantly looking over your shoulder saying, “Yeah, that’s clearly not working. Try the next one.”