Reusable training-loop boilerplate — Lightning, Accelerate, custom — so you spend time on the model, not the wrapper.
Key idea
The training loop is mostly the same across projects. Forward pass, backward, step, scheduler, log, eval, save. Don't rewrite it every time. Use Lightning, Accelerate, or a small in-house scaffold. The model logic is what's project-specific; the loop is plumbing.
Three options. (1) Vanilla PyTorch — the most control, the most boilerplate. (2) HuggingFace Accelerate — minimal changes to your existing loop, handles devices + distributed + mixed precision. (3) PyTorch Lightning — opinionated scaffold; you implement training_step, it handles everything else.
Lightning is great for typical supervised training; Accelerate is great when you want to keep your custom loop but get free distributed + mixed-precision; vanilla is right for unusual training patterns (RL, GANs, custom adversarial setups).
Picking a framework
Lightning: typical supervised training, fast iteration, good defaults
Accelerate: existing custom loop you don't want to throw away
Vanilla PyTorch: unusual training (RL, multi-network adversarial)
In words. Almost every training loop has the same skeleton. The prep stage runs once at startup (build the model, optimizer and data loader). Then a block of four stages (step the model, log metrics, run eval periodically, save a checkpoint) runs many times in a loop; the * is regex-style notation for "repeat zero or more times". At the end, save the final model. Frameworks differ only in how you customise each stage (callbacks, hooks, mixins); the shape is the same.
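$$\text{prep}\;\cdot\;\big(\text{train step}\cdot\text{log}\cdot\text{eval}\cdot\text{checkpoint}\big)^{*}\;\cdot\;\text{save final}$$
train step: forward + backward + optimizer step on one batch
log / evaluate / checkpoint: periodic side-effects inside the loop
Every framework's loop is some specialisation of this skeleton.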
Lightning callbacks. Callbacks are Lightning's extensibility model: EarlyStopping, ModelCheckpoint, LearningRateMonitor, or your own Callback subclasses. They are the right place to inject behaviour without bloating the LightningModule.
Accelerate. Wraps your existing loop. accelerator.prepare(model, optimizer, dataloader) handles device placement, mixed precision and distributed training. Apart from swapping loss.backward() for accelerator.backward(loss), your loop stays mostly the same.
Mixin patterns. Reusable pieces: EMAMixin for exponential moving average of weights, GradAccumMixin for accumulation. Drop into any model.
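A minimal sketch of what an EMAMixin might look like. The method names `init_ema`, `update_ema` and `ema_copy` and the `Net` class are hypothetical, chosen for illustration; the mixin works with any `nn.Module` subclass because it only relies on `self.parameters()`.

```python
import copy

import torch
from torch import nn


class EMAMixin:
    """Hypothetical mixin: adds an EMA shadow of the parameters to an nn.Module subclass."""

    def init_ema(self, decay=0.999):
        self.ema_decay = decay
        # Plain tensors, not registered parameters, so the optimizer never sees them.
        self._ema = [p.detach().clone() for p in self.parameters()]

    @torch.no_grad()
    def update_ema(self):
        # ema = decay * ema + (1 - decay) * live weights
        for e, p in zip(self._ema, self.parameters()):
            e.mul_(self.ema_decay).add_(p.detach(), alpha=1 - self.ema_decay)

    def ema_copy(self):
        """Return a deep copy of the module with the EMA weights loaded, in eval mode."""
        clone = copy.deepcopy(self)
        with torch.no_grad():
            for p, e in zip(clone.parameters(), self._ema):
                p.copy_(e)
        return clone.eval()


class Net(EMAMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)
        self.init_ema(decay=0.9)
```

In a loop you would call `model.update_ema()` right after each `opt.step()` and evaluate with `model.ema_copy()`.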
Custom Lightning DataModules. Encapsulate your data setup: prepare_data (download), setup (split, transform), train_dataloader / val_dataloader. Lightning's separation of model and data scaffolding pays off across multiple projects.
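Trainer flags. fast_dev_run=True pushes a single batch through training and validation as a smoke test, overfit_batches=1 repeatedly trains on one batch as a sanity check (the loss should approach zero), and limit_train_batches=0.1 uses 10% of the training batches for a quick partial run. Lightning has these built in; with Accelerate you implement them yourself behind your own flags.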
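For a custom or Accelerate loop, two small helpers roughly mirror those flags. `limit_batches` and `overfit_batches` are hypothetical names, and the flag names in the usage comment assume your own CLI; nothing here is an Accelerate API.

```python
import itertools
import math


def limit_batches(loader, limit=1.0):
    """Yield at most `limit` batches.

    An int is an absolute batch count; a float is a fraction of len(loader).
    """
    if isinstance(limit, float):
        limit = max(1, math.floor(len(loader) * limit))
    return itertools.islice(loader, limit)


def overfit_batches(loader, n):
    """Cache the first n batches so every epoch reuses them (overfit sanity check)."""
    return list(itertools.islice(loader, n))


# Usage inside your own loop, driven by your own flags, e.g.:
# batches = limit_batches(loader, 1) if args.fast_dev_run else limit_batches(loader, args.limit_train_batches)
# for step, batch in enumerate(batches): ...
```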
Save the LightningModule, not the bare model. A Lightning checkpoint stores the module's weights together with the hyperparameters recorded by self.save_hyperparameters() in __init__; MyModel.load_from_checkpoint(path) then rebuilds the same object with the same configuration. Extract the bare PyTorch model only when you need to deploy to non-Lightning code.
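```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from accelerate import Accelerator

# Toy model/data so the snippet runs end to end; replace with your own.
model = nn.Linear(16, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
data = [{"x": torch.randn(16), "y": torch.randn(1)} for _ in range(512)]
loader = DataLoader(data, batch_size=8, shuffle=True)  # default collate stacks dict fields
loss_fn = nn.MSELoss()

# Keep your own training loop; let Accelerate handle devices, precision and distribution
accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=4)
model, opt, loader = accelerator.prepare(model, opt, loader)

for step, batch in enumerate(loader):
    with accelerator.accumulate(model):  # only syncs and steps every 4th micro-batch
        out = model(batch["x"])
        loss = loss_fn(out, batch["y"])
        accelerator.backward(loss)  # replaces loss.backward()
        opt.step()                  # the prepared optimizer skips steps while accumulating
        opt.zero_grad()
    if step % 100 == 0:
        accelerator.print(f"step={step} loss={loss.item():.4f}")  # prints on the main process only

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    accelerator.save(accelerator.unwrap_model(model).state_dict(), "ckpt.pt")
```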
Going further: custom callbacks, EMA / SWA, and unconventional loops
In words. EMA (exponential moving average) keeps a slowly updating shadow copy of your weights. At each step, the new smoothed weights are mostly the previous smoothed weights, mixed with a small fraction of the latest live weights. μ (mu) is the decay, typically 0.999 or 0.9999; larger values change more slowly (more smoothing). The superscript (t) labels the step number. Use the smoothed copy at evaluation and inference time: it is usually a slightly better, more stable predictor than the latest weights.
$$\theta_{\mathrm{EMA}}^{(t)} = \mu\,\theta_{\mathrm{EMA}}^{(t-1)} + (1-\mu)\,\theta^{(t)}$$
θ_EMA^(t), smoothed weights now: the EMA shadow copy after the current step
μ, the decay factor (close to 1): how much of the previous EMA to keep
θ^(t), current weights: the live trained parameters at this step
Standard for diffusion training, self-supervised methods, and some supervised setups.
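To make "bigger means slower-changing" concrete, this sketch computes how many steps it takes for the EMA to move halfway after the live weights jump from 0 to 1. Solving μ^t = 1/2 gives a half-life of ln(0.5) / ln(μ); the simulation just runs the update rule directly. Both helper names are illustrative.

```python
import math


def ema_half_life(mu):
    """Steps until an old value's weight in the EMA falls to one half: solve mu**t = 0.5."""
    return math.log(0.5) / math.log(mu)


def steps_to_half(mu):
    """Apply the EMA update to a 0 -> 1 step change and count steps until it crosses 0.5."""
    ema, t = 0.0, 0
    while ema < 0.5:
        ema = mu * ema + (1 - mu) * 1.0  # smoothed = mu * previous + (1 - mu) * current
        t += 1
    return t


for mu in (0.999, 0.9999):
    print(f"mu={mu}: half-life ~{math.ceil(ema_half_life(mu))} steps, "
          f"simulated {steps_to_half(mu)}, averaging window ~{round(1 / (1 - mu))} steps")
```

So μ = 0.999 forgets half of an old state in roughly 700 steps and μ = 0.9999 in roughly 7,000, which is why the larger decay suits long runs.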
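EMA / SWA. Maintain a running average of the weights during training and use it at inference. EMA weights recent steps exponentially, with decay typically 0.999 or 0.9999. SWA (Stochastic Weight Averaging) instead takes an equal-weight average of snapshots collected late in training, usually under a constant or cyclical learning rate, and needs BatchNorm statistics recomputed for the averaged weights. Both often improve test accuracy modestly (gains of roughly 0.1-1 pp are commonly reported), but the gain is not guaranteed on every task.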
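PyTorch ships SWA helpers in `torch.optim.swa_utils`. A minimal sketch, assuming a toy model with a BatchNorm layer, random data, and averaging from epoch 3 onward (`swa_start` and `swa_lr` are arbitrary here):

```python
import torch
from torch import nn
from torch.optim.swa_utils import SWALR, AveragedModel, update_bn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
swa_model = AveragedModel(model)     # default: equal-weight running average of parameters
swa_sched = SWALR(opt, swa_lr=0.05)  # anneals to, then holds, a constant SWA learning rate
loader = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(10)]

swa_start, epochs = 3, 6
for epoch in range(epochs):
    for x, y in loader:
        opt.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        opt.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # add this epoch's weights to the average
        swa_sched.step()

# BatchNorm running stats are not averaged; recompute them for the averaged weights.
# update_bn takes the first element of each (x, y) tuple as the model input.
update_bn(loader, swa_model)
```

Passing a custom `avg_fn` (or `multi_avg_fn`) to `AveragedModel` turns the same wrapper into an EMA.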
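Gradient accumulation, done correctly. Run forward + backward on N micro-batches without stepping, then step (and zero the gradients) once. If the loss is a mean, divide each micro-batch loss by N so the accumulated gradient matches the full-batch gradient. Effective batch size is N × micro_batch (times the number of ranks under data parallelism). Watch out for BatchNorm: it normalises with each micro-batch's statistics (per rank, unless you use SyncBatchNorm), not the effective batch's.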
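A quick check of the 1/N scaling, assuming a mean-reduced MSE loss on a toy linear model: four micro-batches of 8 with each loss divided by 4 produce the same gradient as one batch of 32, because `.grad` sums across `backward()` calls.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(5, 1)
x, y = torch.randn(32, 5), torch.randn(32, 1)
loss_fn = nn.MSELoss()  # mean reduction

# Reference: one full batch of 32.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: N = 4 micro-batches of 8, each loss scaled by 1/N.
N = 4
model.zero_grad()
for xb, yb in zip(x.chunk(N), y.chunk(N)):
    (loss_fn(model(xb), yb) / N).backward()  # gradients add up in .grad
accum_grad = model.weight.grad.clone()
# opt.step(); opt.zero_grad() would go here, once per N micro-batches.
```

Without the division the accumulated gradient would be N times too large, which silently behaves like an N-times larger learning rate.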
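Custom unusual loops. GAN training (alternating G and D updates), reinforcement learning (rollouts + updates), curriculum learning (data shifting over time). Lightning supports most of these, for example through manual optimization (self.automatic_optimization = False) for GANs or callbacks for curriculum schedules; pure PyTorch is sometimes cleaner.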
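A pure-PyTorch sketch of the alternating GAN update, assuming a toy 2-D Gaussian as the "real" data and tiny MLPs for G and D; the architecture, learning rates and 200 steps are arbitrary. The two details that matter are detaching the fake batch for the D step and zeroing each optimizer before its own backward pass.

```python
import torch
from torch import nn

torch.manual_seed(0)
G = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))          # noise -> sample
D = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))  # sample -> logit
opt_g = torch.optim.Adam(G.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(D.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()
g_w0 = G[0].weight.detach().clone()


def real_batch(n=64):
    return torch.randn(n, 2) * 0.5 + torch.tensor([2.0, -1.0])  # toy target distribution


for step in range(200):
    real = real_batch()
    fake = G(torch.randn(64, 2))
    # 1) Discriminator step: detach the fakes so no gradient reaches G.
    d_loss = bce(D(real), torch.ones(64, 1)) + bce(D(fake.detach()), torch.zeros(64, 1))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()
    # 2) Generator step: make D label the fakes as real (non-saturating loss).
    #    This backward also fills D's .grad, which opt_d.zero_grad() clears next iteration.
    g_loss = bce(D(fake), torch.ones(64, 1))
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
```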
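Checkpoint hygiene. Save every N steps so runs are resumable mid-epoch. Keep the last K and the best M checkpoints and delete the rest. Store weights as safetensors (a safe, pickle-free, fast-loading format; it does not compress). Keep these training checkpoints separate from "the final model", which is a registry artefact.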
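A sketch of the keep-last-K / keep-best-M policy, assuming each checkpoint is described by a `(step, val_loss, path)` tuple with lower loss being better; `split_checkpoints` and the file names are hypothetical. (Lightning's ModelCheckpoint covers best-M via `save_top_k` plus a single `last.ckpt` via `save_last`.)

```python
def split_checkpoints(ckpts, keep_last=2, keep_best=2):
    """Return (keep, delete) lists of paths, each ordered by step.

    ckpts: list of (step, val_loss, path). Keeps the keep_last most recent and the
    keep_best lowest-val_loss checkpoints; a checkpoint in both sets is kept once.
    """
    by_recency = sorted(ckpts, key=lambda c: c[0], reverse=True)
    by_quality = sorted(ckpts, key=lambda c: c[1])
    keep_paths = {c[2] for c in by_recency[:keep_last]} | {c[2] for c in by_quality[:keep_best]}
    ordered = sorted(ckpts)  # tuples sort by step first
    keep = [c[2] for c in ordered if c[2] in keep_paths]
    delete = [c[2] for c in ordered if c[2] not in keep_paths]
    return keep, delete


ckpts = [
    (100, 0.90, "step100.safetensors"),
    (200, 0.55, "step200.safetensors"),
    (300, 0.60, "step300.safetensors"),
    (400, 0.70, "step400.safetensors"),
    (500, 0.65, "step500.safetensors"),
]
keep, delete = split_checkpoints(ckpts, keep_last=2, keep_best=2)
```

Here steps 400 and 500 survive as the most recent, 200 and 300 as the best, and only step 100 is deleted.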
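Resume on failure. In Lightning, trainer.fit(model, ckpt_path="last.ckpt") restores everything from the checkpoint. With Accelerate, save and restore model + optimizer + scheduler + RNG state; accelerator.save_state(dir) and accelerator.load_state(dir) do this for the objects you passed to prepare(). Essential for jobs longer than a few hours.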
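What "everything" means in plain PyTorch, as a sketch on CPU: `save_state` / `load_state` are hypothetical helper names, and a GPU run would also need `torch.cuda.get_rng_state_all()`. The demo saves, simulates a crash by clobbering weights and RNG seeds, restores, and checks the next random draws match an uninterrupted run.

```python
import os
import random
import tempfile

import torch
from torch import nn


def save_state(path, model, opt, sched, step):
    torch.save({
        "step": step,
        "model": model.state_dict(),
        "opt": opt.state_dict(),      # includes momentum buffers and current lr
        "sched": sched.state_dict(),
        "torch_rng": torch.get_rng_state(),
        "python_rng": random.getstate(),
    }, path)


def load_state(path, model, opt, sched):
    state = torch.load(path, weights_only=False)  # trusted file; contains non-tensor objects
    model.load_state_dict(state["model"])
    opt.load_state_dict(state["opt"])
    sched.load_state_dict(state["sched"])
    torch.set_rng_state(state["torch_rng"])
    random.setstate(state["python_rng"])
    return state["step"]


torch.manual_seed(0)
random.seed(0)
model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)
model(torch.randn(4, 3)).sum().backward()
opt.step()
sched.step()  # lr is now 0.05 and SGD holds a momentum buffer

path = os.path.join(tempfile.mkdtemp(), "state.pt")
save_state(path, model, opt, sched, step=1)
saved_weight = model.weight.detach().clone()
expected_t, expected_p = torch.rand(1), random.random()  # next draws of an uninterrupted run

with torch.no_grad():
    model.weight.zero_()  # simulate a crash: lose the weights and the RNG position
torch.manual_seed(123)
random.seed(123)

resumed_step = load_state(path, model, opt, sched)
```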
Profile inside the training loop. Wrap a few steps in torch.profiler; in Lightning, pass Trainer(profiler="pytorch") (or a PyTorchProfiler instance). Open the exported Chrome trace and look for data-loader gaps, kernel-launch overhead and slow communication.
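A minimal `torch.profiler` sketch on CPU, assuming a toy MLP and five pre-built batches; the `record_function` labels are arbitrary names that show up as their own rows in the summary and as spans in the trace. On GPU you would add `ProfilerActivity.CUDA` to the activities list.

```python
import os
import tempfile

import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function

model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.01)
batches = [(torch.randn(32, 64), torch.randn(32, 1)) for _ in range(5)]

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for x, y in batches:
        with record_function("forward_and_loss"):
            loss = nn.functional.mse_loss(model(x), y)
        with record_function("backward_and_step"):
            opt.zero_grad()
            loss.backward()
            opt.step()

trace_path = os.path.join(tempfile.mkdtemp(), "trace.json")
prof.export_chrome_trace(trace_path)  # open in chrome://tracing or Perfetto
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

For long runs, the profiler's `schedule` argument limits recording to a few steps so the trace stays small.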
Test the training loop. Unit-test the LightningModule's training_step with a hand-built batch. Assert that the loss is finite (and non-negative for losses such as cross-entropy), that every trainable parameter receives a non-zero gradient, and that the output shape matches. This catches refactor bugs without launching a full run.
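A pytest-style sketch of those checks. To keep it framework-free it uses a plain `nn.Module` (`TinyClassifier`, hypothetical) whose `training_step(batch, batch_idx)` has the same shape as a LightningModule's; the test body would be identical for a real LightningModule.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    def __init__(self, in_dim=10, n_classes=3):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, n_classes))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)


def test_training_step():
    torch.manual_seed(0)
    model = TinyClassifier()
    batch = (torch.randn(4, 10), torch.tensor([0, 1, 2, 0]))  # hand-built batch
    assert model(batch[0]).shape == (4, 3)                     # output shape
    loss = model.training_step(batch, 0)
    assert torch.isfinite(loss) and loss.item() >= 0           # finite, non-negative CE
    loss.backward()
    for name, p in model.named_parameters():
        assert p.grad is not None, f"{name} received no gradient"
        assert p.grad.abs().sum() > 0, f"{name} has an all-zero gradient"
```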
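```python
import lightning as L
import torch


class EMACallback(L.Callback):
    """Keep an EMA shadow of the parameters and validate with it."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema = None  # one shadow tensor per parameter (buffers are not averaged)

    def on_train_start(self, trainer, pl_module):
        # Plain tensor copies, created on the device the module already lives on.
        self.ema = [p.detach().clone() for p in pl_module.parameters()]

    @torch.no_grad()
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # ema = decay * ema + (1 - decay) * live. With accumulate_grad_batches > 1
        # this also runs on batches where the optimizer did not step.
        for e, p in zip(self.ema, pl_module.parameters()):
            e.mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def on_validation_epoch_start(self, trainer, pl_module):
        self._swap(pl_module)  # validate with the EMA weights

    def on_validation_epoch_end(self, trainer, pl_module):
        self._swap(pl_module)  # put the live weights back before training resumes

    @torch.no_grad()
    def _swap(self, pl_module):
        # The sanity-check validation runs before on_train_start, when no EMA exists yet.
        if self.ema is None:
            return
        for p, e in zip(pl_module.parameters(), self.ema):
            live = p.detach().clone()
            p.copy_(e)
            e.copy_(live)
```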