How gradient descent grew up — momentum, adaptive learning rates, and the modern toolkit.
Key idea
Vanilla SGD is rarely the fastest. Real loss surfaces have narrow ravines, saddle points, and curvature that differs from one direction to another. Modern optimizers tackle these problems in different ways: momentum remembers past directions, RMSprop and Adagrad scale each parameter's step by its own gradient history, and Adam combines both. The visualization below races them from the same starting point.
Same start, same step budget — watch four optimizers race across a curved loss surface
All four optimizers see the same loss surface from the same start. SGD moves straight down the local gradient and zig-zags in narrow valleys. Momentum accumulates velocity and "rolls through" valleys. RMSprop shrinks per-parameter steps where gradients are big, freeing it to take bigger steps where they're small. Adam combines momentum with RMSprop's adaptive scaling — usually the safest first pick.
SGD. Update is θ ← θ − η·∇L. Cheap, well-understood, and often the best at generalization in deep learning. It is slow to converge on ill-conditioned problems, where one direction needs huge steps and another needs tiny ones.
Momentum. Keep a running velocity that smooths gradient noise and accelerates along consistent directions. Nesterov momentum evaluates the gradient at the look-ahead point, where the velocity is about to carry the parameters. This gives a small but real speedup.
Adam. Maintain per-parameter running estimates of the gradient mean and squared mean, and step along mean / (√(second moment) + ε). In effect it is momentum plus RMSprop. It is the default for most deep learning today, though SGD-with-momentum sometimes generalizes better on vision tasks.
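To make the three update rules concrete, here is a minimal NumPy sketch. Several choices are illustrative assumptions, not taken from the visualization:

- The toy "ravine" is f(x, y) = ½(x² + 50y²).
- The starting point and learning rates are hand-picked.
- The function names (`run_sgd`, `run_momentum`, `run_adam`) are hypothetical helpers, not a library API.

```python
import numpy as np

# Hypothetical test surface: a narrow ravine, curvature 1 along x and 50 along y
H = np.diag([1.0, 50.0])

def loss(w):
    return 0.5 * w @ H @ w

def grad(w):
    return H @ w

def run_sgd(w, lr, steps):
    for _ in range(steps):
        w = w - lr * grad(w)                              # θ ← θ − η∇L
    return w

def run_momentum(w, lr, steps, beta=0.9, nesterov=False):
    v = np.zeros_like(w)
    for _ in range(steps):
        # Nesterov takes the gradient at the look-ahead point the velocity is heading to
        g = grad(w - lr * beta * v) if nesterov else grad(w)
        v = beta * v + g                                  # velocity: decaying sum of past gradients
        w = w - lr * v
    return w

def run_adam(w, lr, steps, b1=0.9, b2=0.999, eps=1e-8):
    m, s = np.zeros_like(w), np.zeros_like(w)
    for t in range(1, steps + 1):
        g = grad(w)
        m = b1 * m + (1 - b1) * g                         # running mean of gradients
        s = b2 * s + (1 - b2) * g**2                      # running mean of squared gradients
        m_hat, s_hat = m / (1 - b1**t), s / (1 - b2**t)   # bias correction
        w = w - lr * m_hat / (np.sqrt(s_hat) + eps)
    return w

start = np.array([10.0, 1.0])
results = {
    "SGD": run_sgd(start, 0.02, 100),
    "Momentum": run_momentum(start, 0.02, 100),
    "Nesterov": run_momentum(start, 0.02, 100, nesterov=True),
    "Adam": run_adam(start, 0.1, 100),
}
for name, w in results.items():
    print(f"{name:9s} loss after 100 steps: {loss(w):.3e}")
```

On this toy ravine, both momentum variants end much closer to the minimum than plain SGD at the same learning rate. The comparison is sensitive to the hand-picked values, so try changing the learning rates and β.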
Practical truth: the learning rate matters more than the optimizer choice. A well-tuned SGD often beats a poorly-tuned Adam. Always sweep the learning rate first.
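One quick way to see why the learning rate dominates: on a quadratic, plain gradient descent is stable only when η < 2/λ_max, where λ_max is the largest curvature. The sketch below assumes a toy quadratic with curvatures 1 and 50, so the threshold is 2/50 = 0.04. It sweeps η on a log grid and reports the final loss. The helper `final_loss` is hypothetical.

```python
import numpy as np

H = np.diag([1.0, 50.0])   # toy quadratic: curvatures 1 and 50, so 2 / λ_max = 0.04

def final_loss(lr, steps=200):
    w = np.array([10.0, 1.0])
    for _ in range(steps):
        w = w - lr * (H @ w)      # plain gradient descent
    return 0.5 * w @ H @ w

# Log-spaced sweep; the last value exceeds 2 / λ_max and diverges
for lr in [1e-3, 3e-3, 1e-2, 3e-2, 1e-1]:
    print(f"lr={lr:<6g} final loss={final_loss(lr):.3e}")
```

Too small a learning rate crawls, and too large a one explodes. The best value sits just below the stability edge, which you only find by sweeping.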
Choose Adam (and friends) when:
You're training transformers or other NLP models (it's the default there)
You don't have time to tune the learning rate per layer or per phase
Gradients are sparse (embeddings, NLP); Adam handles them well
You're training RNNs / LSTMs, which are much friendlier to Adam
Choose SGD-with-momentum when:
You're doing image classification, where it tends to generalize a little better
You've tuned the learning-rate schedule carefully
You're reproducing classical results (most ResNet papers use it)
You want the simplest possible thing in the loop
```python
import torch

model = torch.nn.Linear(10, 1)  # placeholder; substitute your own network

# SGD with Nesterov momentum: the deep-learning workhorse
opt_sgd = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)

# AdamW (Adam with decoupled weight decay): the safe default everywhere else
opt_adam = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Learning rate matters more than optimizer. Always have a schedule;
# call sched.step() once per epoch, so T_max counts epochs here.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt_adam, T_max=100)
```
Want the math: momentum, Adam, second-order, line search?
In words. Adam keeps two running averages of recent gradients.

- The first, velocity, smooths the gradient itself. That's the momentum piece.
- The second, scale, averages the squared gradient. It tells you how big the gradient has typically been on each parameter.

Each step divides the velocity by the square root of the scale, so the step on every parameter is roughly on the scale of the learning rate, whatever its raw gradient size. Parameters with consistently large gradients are scaled down, and parameters with tiny gradients are scaled up.

The β values control how much history each average retains (commonly 0.9 and 0.999). The bias correction divides each average by 1 − βᵗ at step t. It fixes the bias from initialising both averages at zero. With β₂ = 0.999 it still matters noticeably for the first few thousand steps.
gradient: current gradient of the loss
velocity: running average of recent gradients (momentum)
scale: running average of recent squared gradients (per-parameter typical magnitude)
β₁, β₂: how much memory each running average keeps (commonly 0.9 and 0.999)
η: learning rate; ε: tiny constant that prevents divide-by-zero
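Written with the quantities just listed, one Adam step at iteration t reads as follows. The symbols are a notational choice made here: g_t for the gradient, m_t for velocity, s_t for scale, θ for the weights.

$$ \begin{aligned} m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\ s_t &= \beta_2\, s_{t-1} + (1-\beta_2)\, g_t^2 \\ \hat m_t &= \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat s_t = \frac{s_t}{1-\beta_2^{\,t}} \\ \theta_t &= \theta_{t-1} - \eta\, \frac{\hat m_t}{\sqrt{\hat s_t} + \epsilon} \end{aligned} $$

The squares, square roots and division are all per parameter (elementwise).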
Momentum, properly. The velocity update v ← βv + g is an exponentially weighted sum of past gradients. The gradient from k steps ago is weighted by βᵏ, so the memory lasts roughly 1/(1−β) steps (about 10 for β = 0.9). Gradient components that point the same way step after step reinforce, while components that flip sign cancel. This is exactly why momentum helps in narrow ravines: the consistent direction along the ravine floor gets amplified, while the bouncing perpendicular component cancels.
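The amplify-versus-cancel effect can be checked numerically. This small sketch feeds the update v ← βv + g (the same form PyTorch's SGD uses with its default `dampening=0`) two gradient streams:

- a constant one, standing in for the ravine floor;
- a sign-flipping one, standing in for the bounce across the ravine.

The helper `velocity_history` is hypothetical.

```python
import numpy as np

def velocity_history(grads, beta):
    """Apply v ← β·v + g over a sequence of gradients and return every v."""
    v, out = 0.0, []
    for g in grads:
        v = beta * v + g
        out.append(v)
    return np.array(out)

beta = 0.9
# Consistent direction (ravine floor): v approaches g / (1 − β) = 10·g
floor = velocity_history(np.ones(100), beta)
# Sign-flipping direction (bouncing across the ravine): |v| stays near g / (1 + β)
bounce = velocity_history(np.array([(-1.0) ** t for t in range(100)]), beta)
print(floor[-1], abs(bounce[-1]))
```

With β = 0.9 the consistent component ends up about 10× its raw size, while the oscillating one is damped to about 0.53×. That roughly 19:1 ratio is what steers momentum along the ravine floor.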
Adagrad → RMSprop → Adam. Adagrad scales each parameter's step by 1/√(Σ g²). Because that sum grows monotonically, the effective step size shrinks toward zero and learning eventually stalls. RMSprop replaces the sum with an exponential moving average (EMA), fixing the decay-to-zero problem. Adam adds momentum on top. AdamW decouples weight decay from the gradient (Loshchilov & Hutter, 2019); for transformers this is a meaningful improvement.
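The decay-to-zero problem shows up clearly if both rules see a constant gradient of 1. This is a sketch under stated assumptions:

- No momentum and no bias correction.
- An EMA decay of ρ = 0.9. PyTorch's `RMSprop` calls this `alpha` and defaults it to 0.99.
- `effective_steps` is a hypothetical helper.

```python
import numpy as np

def effective_steps(grads, lr=0.01, rho=0.9, eps=1e-8):
    """Per-step update magnitude |Δθ| for Adagrad and RMSprop on a stream of gradients."""
    acc_sum, acc_ema = 0.0, 0.0
    ada, rms = [], []
    for g in grads:
        acc_sum += g * g                              # Adagrad: sum of ALL squared gradients
        acc_ema = rho * acc_ema + (1 - rho) * g * g   # RMSprop: EMA of squared gradients
        ada.append(lr * abs(g) / (np.sqrt(acc_sum) + eps))
        rms.append(lr * abs(g) / (np.sqrt(acc_ema) + eps))
    return np.array(ada), np.array(rms)

ada, rms = effective_steps(np.ones(10_000))
print(ada[[0, 99, 9_999]])   # shrinks like lr / sqrt(t)
print(rms[[0, 99, 9_999]])   # settles at about lr once the EMA has warmed up
```

After 10,000 steps Adagrad's step has fallen to lr/100, while RMSprop's is still about lr. The early RMSprop steps are larger than lr because its EMA starts at zero. Adam's bias correction exists to fix exactly that effect.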
Learning rate schedules. The single highest-leverage knob is the lr schedule, not the optimizer.

- Warm-up is essential for transformers. It is a linear ramp from 0 to the peak lr over the first few hundred steps. Early on, Adam's second-moment estimate is built from only a handful of gradients, so its per-parameter step sizes are unreliable and can be too large.
- Cosine decay is the modern standard for the rest of training.
- One-cycle and triangular cyclical schedules (Smith, 2017) are useful when you want fast convergence.
Second-order methods. Newton's method scales steps by the inverse Hessian. It works wonderfully in low dimensions but is intractable in high ones. Practical approximations include:

- K-FAC (block-diagonal Fisher);
- Shampoo (Kronecker-factored preconditioners built from gradient statistics);
- L-BFGS (history-based quasi-Newton).

They are useful for fine-tuning or small models, but rarely beat a tuned Adam at scale.
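What "scale steps by the inverse Hessian" buys is easiest to see on a toy quadratic, where the Hessian is known exactly. This NumPy sketch compares one Newton step with 100 gradient-descent steps. The matrix, vector and helper names are illustrative assumptions.

```python
import numpy as np

# Toy ill-conditioned quadratic L(w) = ½ wᵀHw − bᵀw, minimized at w* = H⁻¹b
H = np.array([[50.0, 0.0], [0.0, 1.0]])
b = np.array([1.0, 1.0])
w_star = np.linalg.solve(H, b)

def grad(w):
    return H @ w - b

def newton_step(w):
    # Rescale the gradient by the inverse Hessian (solve rather than invert explicitly)
    return w - np.linalg.solve(H, grad(w))

def gd(w, lr, steps):
    for _ in range(steps):
        w = w - lr * grad(w)
    return w

w0 = np.zeros(2)
print("Newton, 1 step:", np.linalg.norm(newton_step(w0) - w_star))
print("GD, 100 steps: ", np.linalg.norm(gd(w0, 0.03, 100) - w_star))
```

Newton lands on the minimum in one step because the quadratic model is exact here. Gradient descent is capped at η < 2/50 by the steep direction and still has an error of about 0.05 after 100 steps.

The catch is the d × d Hessian solve, which is what makes exact Newton intractable for networks with millions of parameters. For the quasi-Newton route, PyTorch ships `torch.optim.LBFGS`, whose `step()` takes a closure that re-evaluates the loss.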
Implicit regularization. Different optimizers reach different solutions on the same loss surface. SGD's gradient noise has been argued to act as a regularizer that finds "flatter" minima. This is part of why it sometimes generalizes better than Adam, even when Adam reaches a lower training loss.
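```python
import math
import torch

# Linear warm-up + cosine decay: the modern transformer recipe
def warm_cos_lr(step, warmup, total, lr_max):
    if step < warmup:
        return lr_max * step / warmup                        # linear ramp from 0
    p = min(1.0, (step - warmup) / max(1, total - warmup))   # decay progress, clamped to [0, 1]
    return 0.5 * lr_max * (1 + math.cos(math.pi * p))        # cosine from lr_max down to 0

model = torch.nn.Linear(10, 1)                   # placeholder model
lr = 3e-4                                        # placeholder peak learning rate
x, y = torch.randn(32, 10), torch.randn(32, 1)   # placeholder batch
# AdamW: weight decay should NOT pass through Adam's denominator
opt = torch.optim.AdamW(model.parameters(), lr=lr,
                        betas=(0.9, 0.95), weight_decay=0.1)
for step in range(200):
    for group in opt.param_groups:
        group["lr"] = warm_cos_lr(step, warmup=20, total=200, lr_max=lr)
    opt.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    # Gradient clipping, between backward() and step(): almost always worth it for transformers / RNNs
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
```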
Want K-FAC, Shampoo, sharpness-aware minimization, and natural gradients?
F: Fisher information matrix, the Hessian of the KL divergence between the model's output distributions at the current and the shifted parameters
F⁻¹: inverse Fisher; rescales the gradient by local curvature in distribution space
Steps are measured in distribution space, so they are reparameterization-invariant
$$ \text{new weights} \;=\; \text{old weights} \;-\; \eta \,\times\, (\text{curvature in distribution space})^{-1} \,\times\, \text{gradient of loss} $$
In words. The natural gradient is "regular gradient descent, but measure distance the right way".

Ordinary SGD treats a unit step in parameter space as equally meaningful for every parameter. The Fisher information matrix F measures how much the model's output distribution actually changes when each parameter moves. Multiplying the gradient by its inverse rescales the step, so a step of a given size produces a similar amount of change in the output distribution, whichever parameters it moves.

The result is reparameterization-invariant, exactly so in the small-step limit: if you rewrote your weights in different units, the trajectory would be the same. In practice F has one row and one column per parameter, far too large to store or invert for a deep network, so methods like K-FAC and Shampoo approximate it cheaply.
old weights, new weights: parameters before and after the step
gradient of loss: same gradient SGD would use
curvature in distribution space: Fisher information, how much the output distribution changes when each parameter moves
(·)⁻¹: inverse; undoes that curvature, so each step produces a controlled change in output
Reparameterization-invariant: rescaling the parameter axes doesn't change the trajectory
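A small worked case makes "measure distance the right way" concrete. Assume, as an example not taken from the original, that we fit a 1-D Gaussian N(μ, σ²) to data, with parameters (μ, log σ). In those coordinates the Fisher matrix is diag(1/σ², 2). So the natural-gradient step for μ does not depend on σ, while the plain gradient step does. All function names are hypothetical.

```python
import numpy as np

def nll_grad(theta, x):
    """Gradient of the mean negative log-likelihood of N(mu, sigma²) w.r.t. (mu, log_sigma)."""
    mu, log_sigma = theta
    var = np.exp(2 * log_sigma)
    d_mu = -(x.mean() - mu) / var
    d_log_sigma = 1.0 - np.mean((x - mu) ** 2) / var
    return np.array([d_mu, d_log_sigma])

def fisher(theta):
    """Fisher information of N(mu, sigma²) in (mu, log_sigma) coordinates."""
    var = np.exp(2 * theta[1])
    return np.diag([1.0 / var, 2.0])

def plain_step(theta, x, lr=1.0):
    return theta - lr * nll_grad(theta, x)

def natural_step(theta, x, lr=1.0):
    # θ ← θ − η F⁻¹ ∇L  (solve instead of forming F⁻¹ explicitly)
    return theta - lr * np.linalg.solve(fisher(theta), nll_grad(theta, x))

x = np.array([1.0, 2.0, 3.0, 6.0])   # toy data with sample mean 3.0
for sigma in (0.1, 1.0, 10.0):
    theta0 = np.array([0.0, np.log(sigma)])
    print(sigma, plain_step(theta0, x)[0], natural_step(theta0, x)[0])
```

With lr = 1 the natural step moves μ onto the sample mean for every starting σ. The plain step moves μ by 3/σ², which is 300 for σ = 0.1 and 0.03 for σ = 10. The Fisher rescaling removes that dependence on how sharply the likelihood curves.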
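K-FAC. Kronecker-Factored Approximate Curvature (Martens & Grosse, 2015) approximates the Fisher as block-diagonal, with one block per layer. Each block is factored as the Kronecker product of two small matrices: the covariance of the layer's inputs and the covariance of the gradients with respect to its outputs. It is practical for medium-sized networks and offers real wall-clock speedups for some workloads, but adds substantial implementation complexity.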
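The payoff of the Kronecker factorization is that the inverse never has to be formed at full size: (A ⊗ G)⁻¹ vec(dW) = vec(G⁻¹ dW A⁻¹). Here A is the input covariance, G is the output-gradient covariance, and vec stacks columns.

The NumPy sketch below shows this for a single linear layer on random stand-in data. It makes several simplifying assumptions:

- damping is added to each factor separately;
- the factors are recomputed on every call rather than amortized;
- `kfac_precondition` is a hypothetical helper.

```python
import numpy as np

def kfac_precondition(dW, A, G, damping=1e-3):
    """K-FAC-style preconditioned gradient for one linear layer (sketch).

    dW: (out, in) weight gradient
    A:  (in, in) covariance of the layer's inputs
    G:  (out, out) covariance of gradients w.r.t. the layer's outputs
    Computes (A ⊗ G)^-1 vec(dW) as G^-1 dW A^-1, never building the full Fisher.
    """
    A_d = A + damping * np.eye(A.shape[0])   # damping keeps both factors invertible
    G_d = G + damping * np.eye(G.shape[0])
    return np.linalg.solve(G_d, dW) @ np.linalg.inv(A_d)

rng = np.random.default_rng(0)
n, d_in, d_out = 64, 4, 3
a = rng.normal(size=(n, d_in))      # stand-in layer inputs for a batch
g = rng.normal(size=(n, d_out))     # stand-in back-propagated output gradients
dW = g.T @ a / n                    # weight gradient, shape (out, in)
A, G = a.T @ a / n, g.T @ g / n     # the two Kronecker factors

precond = kfac_precondition(dW, A, G)
print(precond.shape)                # same shape as dW
```

Inverting the factors costs a 4 × 4 and a 3 × 3 solve instead of a 12 × 12 one. For real layers with thousands of inputs and outputs, this gap is the difference between feasible and impossible.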
Shampoo. Introduced by Gupta, Koren & Singer (2018) and scaled up by Anil et al. (2020). It is a second-order preconditioner that maintains Kronecker factors of the gradient covariance and applies the inverse 1/4 root of each. It is slower per step than Adam but converges in fewer steps, and recent variants are competitive on real-world training.
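For a single weight matrix, the core Shampoo update can be sketched in a few lines of NumPy:

1. Accumulate L += G Gᵀ and R += Gᵀ G.
2. Step along L^(−1/4) G R^(−1/4).

This is a simplified sketch. It omits grafting, infrequent root computation and the other engineering in practical implementations. The class name `ShampooMatrix`, the ε initialization and the hyperparameters are assumptions.

```python
import numpy as np

def inv_root(M, p):
    """M^(-1/p) for a symmetric positive-definite matrix M, via eigendecomposition."""
    w, V = np.linalg.eigh(M)
    return (V * w ** (-1.0 / p)) @ V.T

class ShampooMatrix:
    """Shampoo-style preconditioner for one weight matrix of shape (m, n) (sketch)."""

    def __init__(self, shape, lr=0.1, eps=1e-4):
        m, n = shape
        self.L = eps * np.eye(m)   # left factor: running sum of G Gᵀ
        self.R = eps * np.eye(n)   # right factor: running sum of Gᵀ G
        self.lr = lr

    def step(self, W, G):
        self.L += G @ G.T
        self.R += G.T @ G
        # Precondition from both sides with the inverse 1/4 root of each factor
        return W - self.lr * inv_root(self.L, 4) @ G @ inv_root(self.R, 4)

opt = ShampooMatrix((3, 3), lr=1.0)
W = opt.step(np.zeros((3, 3)), np.diag([1.0, 2.0, 3.0]))
print(np.round(W, 4))
```

For a diagonal gradient the update reduces to g/√(ε + g²) per entry, the same normalization Adagrad applies. The Kronecker factors are what let Shampoo capture correlations between the rows and columns of a weight matrix.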
SAM (Sharpness-Aware Minimization). Foret et al. (2021) look for parameters whose loss stays low across a whole neighbourhood, minimizing max over ‖ε‖ ≤ ρ of L(w + ε). This favours flat minima. Each update adds an inner "find the worst direction nearby" step, approximated by one gradient step of length ρ. SAM often improves generalization on vision tasks, at about 2× the compute cost, since every update needs two forward/backward passes.
Lion / Sophia / Schedule-Free. Three recent optimizers, from Chen et al. (2023), Liu et al. (2024), and Defazio et al. (2024) respectively.

- Lion updates with the sign of an interpolation between momentum and the current gradient.
- Sophia preconditions with a cheap diagonal Hessian estimate and clips the update.
- Schedule-Free removes the need for a decaying lr schedule (warm-up is still used).

Each shows wins on some benchmarks but hasn't displaced AdamW as the default.
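Of the three, Lion's update is the easiest to write down. This minimal NumPy sketch follows the update rule described by Chen et al. (2023). The defaults β₁ = 0.9 and β₂ = 0.99 are assumed, and `lion_step` is a hypothetical helper, not a library function.

```python
import numpy as np

def lion_step(theta, m, g, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    """One Lion update (sketch). Returns the new (theta, m)."""
    c = beta1 * m + (1 - beta1) * g                 # interpolation used only for the direction
    theta = theta - lr * (np.sign(c) + wd * theta)  # each coordinate moves by exactly lr (plus decay)
    m = beta2 * m + (1 - beta2) * g                 # momentum tracked with the slower beta2
    return theta, m

theta, m = np.array([1.0, -2.0, 0.5]), np.zeros(3)
g = np.array([1e-6, -3.0, 500.0])                   # wildly different gradient scales
theta, m = lion_step(theta, m, g, lr=1e-3)
print(theta)                                        # every entry moved by exactly 1e-3
```

Because each coordinate moves by exactly ±η before weight decay, the Lion paper suggests a learning rate several times smaller than AdamW's.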
Why SGD may generalize better. Empirically, SGD-with-momentum often finds "wider" minima than Adam, as measured by smaller Hessian eigenvalues at the solution. There are theoretical arguments, such as modelling SGD's gradient noise as a stochastic differential equation, and counter-arguments. The practical implication is unchanged: try AdamW first, and for vision benchmarks try SGD-with-momentum as a comparison.
Trust region methods. In reinforcement learning, TRPO replaces the learning rate with a hard KL constraint on each policy update, which makes its step a natural-gradient step. PPO approximates the same constraint with a clipped objective. Lessons from these have leaked back into supervised learning, e.g. the "trust region for gradient steps" view of Adam's denominator.
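The link between a KL constraint and the natural gradient can be written out. The notation is introduced here for illustration: d is the parameter change, g = ∇L, and κ is the KL budget. Use the standard second-order approximation KL ≈ ½ dᵀF d, then minimize the linearized loss inside that KL ball with a Lagrange multiplier λ:

$$ \begin{aligned} &\min_{d}\; g^\top d \quad \text{subject to} \quad \tfrac{1}{2}\, d^\top F d \le \kappa \\ &\mathcal{J}(d, \lambda) = g^\top d + \lambda\left(\tfrac{1}{2}\, d^\top F d - \kappa\right) \;\Rightarrow\; g + \lambda F d = 0 \;\Rightarrow\; d = -\tfrac{1}{\lambda}\, F^{-1} g \\ &\tfrac{1}{2}\, d^\top F d = \kappa \;\Rightarrow\; \tfrac{1}{\lambda} = \sqrt{\frac{2\kappa}{g^\top F^{-1} g}} \;\Rightarrow\; d^\star = -\sqrt{\frac{2\kappa}{g^\top F^{-1} g}}\; F^{-1} g \end{aligned} $$

The direction is exactly the natural gradient F⁻¹g. The KL budget κ, rather than a learning rate η, sets how far to move.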
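```python
import torch

# Sharpness-Aware Minimization (Foret et al., 2021): a compact sketch
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_opt, rho=0.05, **kwargs):
        # The wrapped optimizer (e.g. torch.optim.SGD) owns the param groups and does the real update
        self.base = base_opt(params, **kwargs)
        defaults = dict(rho=rho, **self.base.defaults)
        super().__init__(self.base.param_groups, defaults)

    @torch.no_grad()
    def first_step(self):
        # Move to the (approximate) worst point in the rho-ball: e_w = rho * g / ||g||
        norm = torch.norm(torch.stack([p.grad.norm() for g in self.param_groups
                                       for p in g["params"] if p.grad is not None]))
        for g in self.param_groups:
            scale = g["rho"] / (norm + 1e-12)
            for p in g["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)                  # w -> w + e_w
                self.state[p]["e_w"] = e_w

    @torch.no_grad()
    def second_step(self):
        # Undo the perturbation, then step with the gradient computed at w + e_w
        for g in self.param_groups:
            for p in g["params"]:
                if "e_w" in self.state[p]:
                    p.sub_(self.state[p]["e_w"])
        self.base.step()

# Usage (two forward/backward passes per batch; zero the grads between them):
#   opt = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)
#   loss_fn(model(x), y).backward(); opt.first_step(); opt.zero_grad()
#   loss_fn(model(x), y).backward(); opt.second_step(); opt.zero_grad()
```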