15.4 Adam, AdamW, and Adaptive Learning Rate Methods
Alright, let’s talk about the rockstars of optimization: adaptive learning rate methods. You’ve probably heard of Adam. It’s the default optimizer for, well, pretty much everything these days. And for good reason. It’s the workhorse that usually gets the job done without much fuss. But you’re not here for “usually.” You’re here to know why it works, when it might betray you, and what the deal is with its slightly more disciplined cousin, AdamW.
Let’s rewind to the baseline: plain SGD. Simple, elegant, but dumb as a bag of hammers. It uses a single learning rate for every parameter, which is like trying to tune a grand piano by giving every string the same turn of the wrench. Some parameters (the big, slow-moving bass strings) need a gentle nudge. Others (the high, tight treble strings) need quicker, more precise adjustments. This is the problem adaptive methods solve: they give each parameter its own effective learning rate, updated based on that parameter’s own gradient history.
The Core Idea: Momentum and Squared Gradients
Adam (which stands for Adaptive Moment Estimation) is essentially the lovechild of two other clever ideas: RMSProp and SGD with Momentum.
Think of Momentum like this. Plain SGD is a ball with no inertia rolling down a hill. Each step depends only on the current gradient, so it stalls at the bottom of every little valley (a local minimum) and gets knocked around by every bump. Momentum makes it a heavy ball. It has inertia. It rolls through small bumps and shallow valleys, which helps it avoid getting stuck and generally move faster downhill. It does this by keeping an exponentially weighted running average of past gradients (the direction) and stepping along that average instead of along the raw gradient.
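Here is a minimal pure-Python sketch of that running average. It assumes the exponentially weighted form that Adam uses, m = beta * m + (1 - beta) * g. With default settings, torch.optim.SGD’s momentum keeps an unnormalized buffer instead (buf = momentum * buf + g), so read this as an illustration of the smoothing idea rather than a copy of any library. The gradient sequence is made up: a steady downhill pull of 0.5 plus a bump of size 1 that flips sign every step.

```python
def ema(values, beta=0.9):
    """Exponentially weighted running average, the smoothing behind momentum."""
    m = 0.0
    history = []
    for g in values:
        m = beta * m + (1.0 - beta) * g
        history.append(m)
    return history

# Hypothetical gradients: steady pull of 0.5 plus a +/-1 bump that flips every step.
grads = [0.5 + (1.0 if t % 2 == 0 else -1.0) for t in range(200)]
smoothed = ema(grads, beta=0.9)

# Compare how much the raw gradient and the momentum average swing near the end.
raw_swing = max(grads[-10:]) - min(grads[-10:])
smooth_swing = max(smoothed[-10:]) - min(smoothed[-10:])
print(f"raw swing: {raw_swing:.3f}, momentum swing: {smooth_swing:.3f}")
print(f"final momentum value: {smoothed[-1]:.3f} (true mean 0.5)")
```

The raw gradient jumps by 2.0 from one step to the next. The momentum average only wobbles by roughly 0.1 around the true downhill direction of 0.5.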
Now, RMSProp is the other parent. It’s all about adjusting the learning rate per parameter. It keeps a running average of the squared gradients, which tracks their magnitude while ignoring their sign. If a parameter has been getting huge gradients (a volatile, “high-frequency” string), RMSProp divides its update by a large number, which shrinks its effective learning rate and avoids overshooting. If the gradients have been small (a stable parameter), it divides by a small number, so the effective learning rate stays relatively large.
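The per-parameter normalization is easy to see numerically. This pure-Python sketch applies an RMSProp-style update to two hypothetical parameters that receive constant gradients of 100 and 0.01. The decay rate 0.9 and eps=1e-8 are assumptions chosen for the demo; torch.optim.RMSprop defaults to alpha=0.99.

```python
import math

def rmsprop_last_update(grad, steps=100, lr=1e-3, beta=0.9, eps=1e-8):
    """Run RMSProp on a constant gradient and return the size of the last update."""
    v = 0.0
    update = 0.0
    for _ in range(steps):
        v = beta * v + (1.0 - beta) * grad ** 2    # running mean of squared gradients
        update = lr * grad / (math.sqrt(v) + eps)  # divide by the RMS gradient magnitude
    return update

big = rmsprop_last_update(grad=100.0)    # volatile, large-gradient parameter
small = rmsprop_last_update(grad=0.01)   # quiet, small-gradient parameter
print(f"plain SGD steps would differ by {100.0 / 0.01:.0f}x")
print(f"RMSProp steps: {big:.6f} vs {small:.6f}")
```

Plain SGD with one shared learning rate would move the first parameter ten thousand times further than the second. After dividing by the root-mean-square gradient, both take steps of essentially the learning rate.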
Adam combines these. It keeps two running averages for each parameter:
- The first moment (m): an exponentially weighted running average of the gradients (like Momentum, for direction).
- The second moment (v): an exponentially weighted running average of the squared gradients (like RMSProp, for magnitude).
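Each parameter’s update is then roughly lr * m / sqrt(v). You take the smoothed direction m and divide it by sqrt(v), the root-mean-square size of that parameter’s recent gradients. Parameters with large, volatile gradients therefore take proportionally smaller steps. That ratio is what people mean by the “adaptive” learning rate. Brilliant.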
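Here are the two running averages and the resulting update written out for a single parameter, still without bias correction. The gradient at step t is g_t. The notation is ours: α is the shared learning rate, θ the parameter, and m_0 = v_0 = 0. The per-parameter adaptation comes entirely from the division by the square root of v_t.

```latex
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2 \\
\theta_t &= \theta_{t-1} - \alpha\,\frac{m_t}{\sqrt{v_t}}
\end{aligned}
```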
Here’s the kicker, and it’s the first thing everyone gets wrong: those running averages m and v start at zero. In the first few steps both are biased towards zero. v is biased much more than m, because beta2 is closer to 1. The raw ratio m / sqrt(v) is therefore distorted: with the default betas, the very first step comes out roughly three times too large. The original Adam paper fixes this with bias correction. At step t it divides m by (1 - beta1^t) and v by (1 - beta2^t), which undoes the zero-initialization bias and fades away as t grows. You must, must, MUST implement this. Thankfully, every modern library does.
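A quick worked example shows why the correction matters. Take the first step (t = 1) with the default betas 0.9 and 0.999 and any nonzero gradient g_1. The correction divides each average by (1 - β^t), which at t = 1 is 0.1 and 0.001 respectively.

```latex
\begin{aligned}
m_1 &= 0.1\, g_1, \qquad v_1 = 0.001\, g_1^2 \\
\frac{m_1}{\sqrt{v_1}} &= \frac{0.1}{\sqrt{0.001}}\,\operatorname{sign}(g_1) \approx 3.16\,\operatorname{sign}(g_1) \\
\hat m_1 &= \frac{m_1}{1-\beta_1} = g_1, \qquad \hat v_1 = \frac{v_1}{1-\beta_2} = g_1^2 \\
\frac{\hat m_1}{\sqrt{\hat v_1}} &= \operatorname{sign}(g_1)
\end{aligned}
```

Both averages start out too small, but v is shrunk far more than m. The uncorrected ratio is therefore about three times too large. After correction, the first step has magnitude equal to the learning rate, as intended.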
import torch
# Let's say we have some model parameters
model = torch.nn.Linear(10, 1)
# This is the classic, bias-corrected Adam. Note the 'betas' tuple.
# beta1 controls the decay of the first moment (m, momentum)
# beta2 controls the decay of the second moment (v, squared gradients)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
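To see the machinery instead of just calling it, here is a pure-Python sketch of one bias-corrected Adam step on a list of scalar parameters. The function name adam_step and the dictionary state layout are our own illustration, not PyTorch internals. As far as we know, torch.optim.Adam computes the same update with its default settings: lr * m_hat / (sqrt(v_hat) + eps).

```python
import math

def adam_step(params, grads, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """One bias-corrected Adam update; mutates `params` and `state` in place."""
    beta1, beta2 = betas
    state["t"] += 1
    t = state["t"]
    for i, g in enumerate(grads):
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g      # first moment
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * g * g  # second moment
        m_hat = state["m"][i] / (1 - beta1 ** t)  # undo the zero-initialization bias
        v_hat = state["v"][i] / (1 - beta2 ** t)
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)

params = [1.0, 1.0]
state = {"t": 0, "m": [0.0, 0.0], "v": [0.0, 0.0]}
adam_step(params, grads=[5.0, 0.002], state=state)
print(params)  # both moved by about lr = 0.001, despite a 2500x gradient gap
```

Thanks to bias correction, the first step has size lr for both parameters, regardless of how different their raw gradients are.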
The Nuts and Bolts: Betas and Epsilon
Those betas are hyperparameters, but the defaults work well for most problems. An exponential moving average with decay beta has an effective memory of about 1 / (1 - beta) steps. So beta1=0.9 means the momentum average looks about 10 steps back. beta2=0.999 gives the squared-gradient average a much longer memory of about 1000 steps. This makes sense: you want a fresh estimate of the direction but a stable, long-term estimate of each parameter’s volatility.
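The “about 10 steps” and “about 1000 steps” figures come from the effective horizon 1 / (1 - beta) of an exponential moving average. This illustrative pure-Python check computes what share of the average’s total weight the most recent N inputs carry. That share is 1 - beta^N.

```python
def recent_weight(beta, n):
    """Share of an exponential moving average's weight held by the last n inputs."""
    return 1.0 - beta ** n

for beta in (0.9, 0.999):
    horizon = round(1.0 / (1.0 - beta))  # effective memory in steps
    share = recent_weight(beta, horizon)
    print(f"beta={beta}: horizon ~{horizon} steps, which carry {share:.0%} of the weight")
```

In both cases the horizon holds roughly 1 - 1/e, or about 63 to 65 percent, of the weight. The horizon is therefore a fair summary of each average’s memory, even though older gradients never drop out entirely.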
Then there’s epsilon (eps in code), a tiny number (1e-8 by default in PyTorch) added to the denominator: step = m_hat / (sqrt(v_hat) + eps), where m_hat and v_hat are the bias-corrected moments. Its job is to prevent division by zero. But it’s also a sneaky source of problems. If eps is too large, it swamps sqrt(v) for parameters with small gradients. The per-parameter normalization then disappears, and Adam degrades into something close to SGD with momentum. If it’s too small, you risk numerical instability, especially in low-precision arithmetic. Just leave it at the default unless you have a specific reason not to; the library authors have fought this battle for you.
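Here is a tiny numerical illustration of the “too large eps” failure mode, using the same formula. The values are made up for a parameter with consistently small gradients: m = 1e-4 and v = 1e-8, so sqrt(v) = 1e-4.

```python
import math

def normalized_step(m, v, eps):
    """Adam's per-parameter step before multiplying by the learning rate."""
    return m / (math.sqrt(v) + eps)

m, v = 1e-4, 1e-8  # hypothetical bias-corrected moments for a small-gradient parameter
for eps in (1e-8, 1e-3, 1e-1):
    print(f"eps={eps:g}: step = {normalized_step(m, v, eps):.6f}")
```

With the default eps the normalized step is about 1, as Adam intends. As eps grows, it dominates the denominator and the step collapses toward m / eps, which is just a rescaled momentum-SGD step.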
Adam’s Dirty Secret: Weight Decay Isn’t What You Think
Here’s where the plot thickens. You know weight decay from good old SGD, right? There it’s equivalent to L2 regularization. Adding (λ/2) * ||w||^2 to the loss adds λ * w to the gradient, so every update includes a term - lr * λ * w that shrinks the weights towards zero on every step. It’s beautiful and effective.
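For plain SGD, the L2 penalty and direct weight shrinkage are the same operation, which is why the two names get used interchangeably. Here is a pure-Python check on one scalar weight, with arbitrary illustrative values.

```python
lr, lam = 0.1, 0.01
w, g = 2.0, 0.5  # hypothetical weight and loss gradient

# L2 regularization: add the penalty's gradient (lam * w) to the loss gradient.
w_l2 = w - lr * (g + lam * w)

# Weight decay: shrink the weight directly, then take the plain gradient step.
w_decay = w * (1 - lr * lam) - lr * g

print(w_l2, w_decay)  # equal up to floating-point rounding
```

Expanding both lines gives w - lr * g - lr * lam * w, so in SGD it does not matter which form you implement. That equivalence is exactly what breaks once an adaptive denominator enters the update.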
Now look at how weight decay has traditionally been bolted onto Adam, for example through the weight_decay argument of torch.optim.Adam. It is implemented as that same L2 term: λ * w is simply added to the gradient before the adaptive scaling is applied. With Adam, that is a real problem.
Why? Because the decay term now passes through the same machinery as the gradient: it gets folded into m and v and divided by sqrt(v). A parameter with large historical gradients has a large sqrt(v), so its decay is scaled down. A parameter with small gradients gets decayed much harder. Adaptive scaling and L2 decay become entangled, and the effective regularization on each weight depends on its gradient history, which is a mess. The weight decay hyperparameter no longer means what you think it means, and tuning it becomes a nightmare.
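To make the entanglement concrete, here is a pure-Python toy in which all values are hypothetical. Two weights both start at 1.0. Their loss gradients are zero-mean and flip sign every step, but one has amplitude 1.0 and the other 0.01. We run Adam with the decay added to the gradient, and Adam with the decay applied directly to the weights, then compare where each weight ends up. The decoupled branch follows the form we understand torch.optim.AdamW to use (shrink the weights, then take the Adam step), but this is a sketch, not library code.

```python
import math

def run(grad_scale, decoupled, steps=1000, lr=1e-3, wd=0.01,
        betas=(0.9, 0.999), eps=1e-8):
    """Train one scalar weight on sign-flipping gradients; return its final value."""
    beta1, beta2 = betas
    w, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_scale if t % 2 else -grad_scale  # zero-mean "task" gradient
        if decoupled:
            w *= 1 - lr * wd                      # decay applied directly to the weight
        else:
            g += wd * w                           # decay folded into the gradient (L2)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w

for decoupled in (False, True):
    name = "decoupled decay" if decoupled else "L2 in gradient"
    big, small = run(1.0, decoupled), run(0.01, decoupled)
    print(f"{name}: w_big={big:.4f}, w_small={small:.4f}")
```

With L2 in the gradient, the weight with small gradients shrinks far more than the one with large gradients, even though both use the same decay coefficient. With decoupled decay, the two weights end up at the same value. Adam’s update is, up to eps, invariant to gradient scale, and the shrinkage no longer passes through it.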
Enter AdamW: The Fix We Deserved
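AdamW (W for Weight Decay) is the correction. It’s not a fundamentally new optimizer; it’s Adam with weight decay done right. In AdamW, weight decay is decoupled from the adaptive learning rate calculation. The gradient fed into m and v contains only the loss gradient. The weights are then shrunk directly by a separate term proportional to λ * w, which never gets divided by sqrt(v), just like weight decay in plain SGD.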
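Here are the two updates side by side, in common textbook form. θ is a parameter, α the learning rate, λ the weight-decay coefficient, and ε the small constant. m̂_t and v̂_t are the bias-corrected moments, built in both cases from the g_t on the same line. Individual libraries may scale λ slightly differently. As we understand it, torch.optim.AdamW multiplies the decay by the learning rate, as written here.

```latex
\begin{aligned}
\text{Adam + L2:}\quad & g_t = \nabla f(\theta_{t-1}) + \lambda\,\theta_{t-1}, \qquad
\theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \\
\text{AdamW:}\quad & g_t = \nabla f(\theta_{t-1}), \qquad
\theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} - \alpha\lambda\,\theta_{t-1}
\end{aligned}
```

In the first line the decay term sits inside g_t, so it ends up divided by the square root of v̂_t. In the second it bypasses the adaptive scaling entirely.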
This separation of concerns is a game-changer. It makes the weight decay hyperparameter actually mean what you think it means: a consistent, direct shrinkage of the weights, independent of the gradient history. That makes it much easier to tune. In practice it also often generalizes better (i.e., produces models that work better on unseen data) than Adam with L2-style decay.
# This is the way. Notice it's optim.AdamW, not optim.Adam.
# weight_decay here is decoupled weight decay applied directly to the weights, not an L2 term added to the gradient.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
When to Use What (The Real-World Advice)
So, do you always use AdamW? Almost always, yes. It’s my default starting point for most projects, especially with deep networks and transformers. It usually converges quickly and reliably with little tuning.
But don’t be a dogmatic fool. There are still cases where SGD with Momentum and a good learning rate schedule can outperform Adam(W). The best-known example is image classification with convolutional networks, where SGD is often paired with schedules such as the one-cycle policy behind “super-convergence” training. The generalization performance of well-tuned SGD is still legendary. The problem is that “well-tuned” part: it requires more babysitting.
The biggest pitfall with Adam is that it sometimes ends up at worse solutions than well-tuned SGD in the sense that matters. Training loss drops quickly, but validation performance plateaus below what SGD eventually reaches. Why this happens is still debated. One popular explanation is that the two optimizers settle into different kinds of minima that generalize differently. If you see validation metrics stagnating at a worse value than you’d like, don’t be afraid to switch to SGD with momentum and a low learning rate for the final phase of training, to fine-tune and “polish” the solution. Think of AdamW as getting you most of the way there quickly, and SGD as your master craftsman finishing the last stretch.
The bottom line: Understand the machinery. Use AdamW by default. And always know which defaults you are relying on, so you can question them when results disappoint. Your model’s performance depends on it.