Structured logs, the right print statements, and which debugger tricks work on tensors.
Key idea
"NaN" is a clue, not a stack trace. ML bugs usually present as silent degradation — bad metrics, divergent loss, slow training — not as crashes. The trick is making each level of the stack legible: structured logs you can search, useful prints at the right verbosity, a debugger that handles tensors.
Use real logging, not print. Python's logging module — or better, loguru — gives you levels, timestamps, structured records, and rotation for free. The five-minute investment pays off forever.
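If you'd rather stay in the standard library, you can get levels, timestamps and size-based rotation from logging plus a RotatingFileHandler. This is a minimal sketch. The helper name make_logger, the file name and the limits are illustrative choices, not a fixed recipe:

```python
import logging
from logging.handlers import RotatingFileHandler

def make_logger(path="train.log", name="train", max_bytes=100 * 1024 * 1024, backups=5):
    """Hypothetical helper: a logger writing timestamped, levelled lines to a size-rotated file."""
    logger = logging.getLogger(name)
    if not logger.handlers:                   # avoid stacking duplicate handlers on repeated calls
        handler = RotatingFileHandler(path, maxBytes=max_bytes, backupCount=backups)
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        logger.propagate = False              # don't also emit through the root logger
    return logger
```

Typical use is `log = make_logger(); log.info("step=%d loss=%.4f", step, loss)`. The %-style arguments are formatted only if the record is actually emitted, so debug-level calls stay cheap when they are filtered out.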
Log what matters. Loss, learning rate, gradient norms, data shapes, the configuration object on startup. Not every batch — sample every N batches. If you're using a tracker (W&B / MLflow), it's also your log.
Most common ML bugs. NaN / Inf propagation, shape mismatches (sometimes hidden by silent broadcasting), exploding or vanishing gradients, a learning rate that is too high, data corruption upstream of training, label noise, and distribution shift between train and val.
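from loguru import logger
import torch

# Structured logging with levels; the file rotates when it reaches 100 MB.
# cfg, num_steps, model, train_step and debug_dump are defined elsewhere.
logger.add("train.log", rotation="100 MB", level="INFO",
           format="{time} {level} {message}")
logger.info("config: {}", cfg)

for step in range(num_steps):
    loss, metrics = train_step()              # forward, backward, optimizer step
    if not torch.isfinite(loss):              # catches NaN and +/-Inf
        logger.error("Non-finite loss at step {}", step)
        debug_dump(step)
        raise RuntimeError("NaN/Inf loss")
    if step % 100 == 0:                       # sample every N steps, not every batch
        # Global L2 norm: sqrt of the sum of squared per-parameter norms
        # (summing the norms themselves overstates it).
        gnorm = sum(p.grad.norm().item() ** 2
                    for p in model.parameters() if p.grad is not None) ** 0.5
        logger.info("step={} loss={:.4f} gnorm={:.4f}", step, loss.item(), gnorm)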
What to watch in training
Loss: smooth decrease; NaN spike = something exploded
Gradient norm: stable; growing = lr too high; collapsing = vanishing
Activation stats: most activations should be neither 0 nor saturated
Learning rate: log it explicitly; schedule bugs are common
GPU utilization: low util usually = data-loading bottleneck
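Before reading too much into a low GPU-utilisation number, you can measure how much of each step is spent waiting for the next batch. This is a minimal sketch. The function name is hypothetical, and the loader and step function below are stand-ins that use time.sleep:

```python
import time

def time_data_vs_compute(loader, step_fn, max_steps=50):
    """Return the fraction of wall-clock time spent waiting on the data loader.

    loader: any iterable of batches; step_fn: callable running one training step.
    On GPU, step_fn should end with torch.cuda.synchronize() so the timing covers
    finished kernels rather than merely queued ones.
    """
    data_t = compute_t = 0.0
    it = iter(loader)
    for _ in range(max_steps):
        t0 = time.perf_counter()
        try:
            batch = next(it)
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)
        t2 = time.perf_counter()
        data_t += t1 - t0
        compute_t += t2 - t1
    total = data_t + compute_t
    return data_t / total if total > 0 else 0.0

# Stand-ins: a slow "loader" (simulated decoding cost) and a fast "step"
def slow_loader(n=5):
    for i in range(n):
        time.sleep(0.02)
        yield i

frac = time_data_vs_compute(slow_loader(), lambda b: time.sleep(0.002))
print(f"fraction of time waiting on data: {frac:.0%}")
```

If the data fraction is large, the accelerator is being starved. Raising the DataLoader's num_workers or moving preprocessing off the hot path is usually the first thing to try.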
Common silent failures
Wrong data type (fp16 underflow, int truncation)
Broadcasting where you didn't expect (e.g. (B,) vs (B, 1))
Detached graph — gradients don't flow
Frozen layers you forgot to unfreeze
Wrong device: part of the model or pipeline silently stays on CPU (slow, but no error)
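The (B,) vs (B, 1) case is worth seeing once, because nothing errors. The sketch below uses made-up tensors. Subtracting a (B,) target from a (B, 1) prediction broadcasts to (B, B), and a hand-written MSE then silently averages over every pair:

```python
import torch

torch.manual_seed(0)
B = 4
pred = torch.randn(B, 1)      # e.g. a regression head with output shape (B, 1)
target = torch.randn(B)       # targets loaded with shape (B,)

# Silent bug: (B, 1) - (B,) broadcasts to (B, B); the mean runs over B*B pairs.
bad = ((pred - target) ** 2).mean()
print((pred - target).shape)  # torch.Size([4, 4])

# Fix: make the shapes agree explicitly, and assert them so a regression fails loudly.
good = ((pred.squeeze(-1) - target) ** 2).mean()
assert pred.squeeze(-1).shape == target.shape
```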
Re-init: is the initialisation scheme wrong for this depth or activation function?
Add eps to denominators, clip gradients
Check the data: outliers, label encoding, NaNs upstream
Try fp32: if the problem disappears, fp16 underflow / overflow was the cause
$$ \text{loss goes NaN or diverges} \;\Rightarrow\; \text{check: learning rate, initialisation, numerical stability, data, precision} $$
In words. When training blows up (the loss becomes NaN or shoots towards infinity), there are five usual suspects, listed in roughly descending order of frequency. The ⇒ ("implies") in the math means "if the left side happens, investigate the right". A learning rate that is too high is by far the most common; try dropping it 10×. A poor initialisation can make deep networks explode from the very first step. Numerical stability covers things like log(0) or division by very small numbers. Data issues (NaNs in inputs, outliers, wrong label encoding) are subtle. Precision problems (fp16 underflow / overflow) are increasingly common with mixed-precision training.
NaN / diverge: the loss becomes "not a number" or grows without bound
lr: learning rate; try cutting it by 10×
init: weight initialisation scheme
numerical stability: add ε to denominators, use log(x + ε), clip gradients
data: NaNs, outliers, or bad labels upstream
precision: fp16 / bf16 underflow or overflow; try fp32
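The numerical-stability fixes are easiest to see on tiny inputs. This sketch uses made-up values to show how NaN and Inf appear from log(0), 0/0 and exp overflow, and how an ε or a fused primitive removes them. The value eps = 1e-8 is an illustrative choice for fp32:

```python
import torch

eps = 1e-8                        # fine for fp32; in fp16 it rounds to 0 (see below)

# log(0) = -inf, and 0 * -inf (as in an entropy term p * log p) is NaN.
p = torch.tensor([0.0, 0.5, 1.0])
naive = p * torch.log(p)          # first entry is nan
safe = p * torch.log(p + eps)     # first entry is 0 * log(1e-8) = 0

# Normalising by a std that can be 0 gives 0 / 0 = NaN.
x = torch.ones(5)                 # constant input, so std == 0
naive_norm = (x - x.mean()) / x.std()
safe_norm = (x - x.mean()) / (x.std() + eps)

# Prefer fused, stable primitives to hand-rolled formulas where they exist.
logits = torch.tensor([1000.0, 1000.0])
naive_lse = torch.log(torch.exp(logits).sum())   # exp overflows: inf
stable_lse = torch.logsumexp(logits, dim=0)      # 1000 + log(2)

# Precision matters for eps too: 1e-8 is below fp16's smallest subnormal.
print(torch.tensor(eps, dtype=torch.float16))    # tensor(0., dtype=torch.float16)
```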
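Anomaly detection during training. torch.autograd.set_detect_anomaly(True) is slow. In exchange, when a backward pass produces NaN it raises an error and prints the traceback of the forward operation responsible. Use it for one debug run; turn it off for production.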
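The scoped form, torch.autograd.detect_anomaly(), is handy for a single suspicious step. In this minimal, contrived example the forward value is finite, but the backward of sqrt at 0 computes 0 / 0:

```python
import torch

x = torch.zeros(1, requires_grad=True)
caught = None

with torch.autograd.detect_anomaly():
    # Forward is finite: sqrt(0) * 0 = 0. Backward of sqrt is grad / (2 * sqrt(x)),
    # and here grad is 0 (from the * 0), giving 0 / 0 = NaN.
    y = (torch.sqrt(x) * 0).sum()
    try:
        y.backward()
    except RuntimeError as err:
        # Anomaly mode raises at the backward node that produced NaN, and warns
        # with the traceback of the forward op that created that node.
        caught = err
        print(err)
```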
Gradient clipping. torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) is almost always worth it for transformers and RNNs. It doubles as a diagnostic: if clipping alone fixes the divergence, you have an exploding-gradient problem.
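clip_grad_norm_ returns the total gradient norm measured before clipping, so that value can serve as your gnorm log as well. In this small sketch the gradient is set by hand to have L2 norm 5:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1, bias=False)
model.weight.grad = torch.tensor([[3.0, 4.0]])   # hand-set gradient, L2 norm 5

# Rescales gradients in place so their global norm is at most max_norm,
# and returns the pre-clipping total norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(total_norm.item())          # 5.0
print(model.weight.grad)          # approximately [[0.6, 0.8]]
```

With mixed precision and torch.cuda.amp.GradScaler, call scaler.unscale_(optimizer) before clipping. Otherwise the threshold is applied to scaled gradients.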
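NaN forensics. Hook every layer's forward (and, if needed, backward) pass to check for NaN. Hooks fire in execution order, so the first layer that reports one is the culprit. Common sources: division by zero in normalisation (e.g. a zero variance without eps), log(0), exp of a large value (overflow, especially in fp16), and softmax over a fully masked row where every logit is -inf.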
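The fully masked softmax row is a classic in attention code with padding. This sketch uses made-up scores and a mask. It shows the NaN and one possible fix: fill with the dtype's most negative finite value instead of -inf, then zero out rows that have nothing to attend to. Other fixes exist; this one is illustrative:

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 2.0, 3.0]])
mask = torch.tensor([[True, True, False],
                     [False, False, False]])   # True = keep; row 1 is all padding

# Common pattern: masked positions set to -inf. Row 1 becomes all -inf -> 0/0.
naive = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
print(naive[1])   # tensor([nan, nan, nan])

# Fix: most negative finite value instead of -inf, then zero the empty rows.
neg = torch.finfo(scores.dtype).min
safe = torch.softmax(scores.masked_fill(~mask, neg), dim=-1)
safe = safe * mask.any(dim=-1, keepdim=True)
```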
The 1-batch overfit. Take 2–8 examples and train on them alone. The model should drive the training loss to near zero, typically in well under 1000 steps. If it can't, something is fundamentally wrong (architecture, loss, data pipeline).
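Here is what the check can look like. This is a minimal sketch with a tiny stand-in MLP and 8 random examples. The model size, optimizer, learning rate and step count are illustrative assumptions, not tuned values. Even with random labels, a working setup should memorise a handful of points:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 10)                 # 8 examples, 10 features
y = torch.randint(0, 2, (8,))          # random labels: memorisation is the point
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 2))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

for step in range(500):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()

print(f"final loss {loss.item():.4f}")
# If this stays well above zero, suspect the loss, the labels or the data
# pipeline before touching hyperparameters.
```

In a real project, swap in your own model, loss and one real batch from your DataLoader. The check is only meaningful if it exercises the same code path as training.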
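Tensor-aware debugger. pdb works fine. Drop a breakpoint() into the code, then run pdb's display command on an expression such as (t.shape, t.dtype, t.device, t.requires_grad); pdb re-prints it whenever its value changes at a stop. ipdb adds tab completion and syntax highlighting. PyCharm and VS Code have visual tensor inspectors.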
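A small helper you can call from the debugger prompt (p describe(x)) saves retyping those attributes. The name describe and its output format are hypothetical. This is a sketch, not a library function:

```python
import torch

def describe(t: torch.Tensor) -> str:
    """Hypothetical one-line tensor summary for use inside pdb / ipdb."""
    parts = [f"shape={tuple(t.shape)}", f"dtype={t.dtype}", f"device={t.device}",
             f"requires_grad={t.requires_grad}"]
    if t.numel() > 0 and t.is_floating_point():
        parts.append(f"nan={int(torch.isnan(t).sum())}")
        parts.append(f"inf={int(torch.isinf(t).sum())}")
        finite = t.detach()[torch.isfinite(t)]     # min/max over finite entries only
        if finite.numel() > 0:
            parts.append(f"min={finite.min().item():.4g} max={finite.max().item():.4g}")
    return " ".join(parts)

print(describe(torch.tensor([1.0, float("nan"), 3.0])))
```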
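The pickle vs JSON rule. Save configs, hyperparameters, and metrics as JSON or YAML: diff-friendly and language-agnostic. Save model weights with torch.save (pickle-based) or safetensors. Prefer safetensors for anything you might load from an untrusted source, since unpickling can execute arbitrary code. Don't mix the two purposes.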
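import torch

# NaN-source finder: hook every module. Hooks fire in execution order,
# so the first message names the first module whose output went non-finite.
def attach_nan_hooks(model):
    def hook(name):
        def fn(module, inp, out):
            outs = out if isinstance(out, (tuple, list)) else (out,)
            for o in outs:
                if isinstance(o, torch.Tensor) and not torch.isfinite(o).all():
                    print(f"NaN/Inf in forward output of {name or '<root>'} "
                          f"({type(module).__name__})")
        return fn
    # Return the handles so the hooks can be removed (h.remove()) after debugging
    return [mod.register_forward_hook(hook(name)) for name, mod in model.named_modules()]

# Print per-layer activation stats to spot dead or saturating layers
def activation_summary(model, x, sat_threshold=5.0):
    activations = {}
    def hook(name):
        def fn(m, i, o):
            if isinstance(o, torch.Tensor):
                activations[name] = o.detach().float()
        return fn
    handles = [mod.register_forward_hook(hook(name)) for name, mod in model.named_modules()]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()            # don't leave hooks attached after the summary
    for name, a in activations.items():
        print(f"{name or '<root>':30s} mean={a.mean():.3f} std={a.std():.3f} "
              f"max={a.max():.3f} sat={(a.abs() > sat_threshold).float().mean():.2%}")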
Loss curve patterns
$$ \text{divergence}, \;\text{plateau}, \;\text{cliff}, \;\text{oscillation}, \;\text{step decay} \;\to\; \text{each implies a specific bug class} $$
Each pattern has a small set of likely causes
Documented troubleshooting tree → much faster than guessing
In words. Loss curves come in a small zoo of recognisable shapes, and each shape narrows the search for what's wrong. Divergence (shooting up) usually means lr is too high or there's a numerical issue. Plateau from step 0 means gradients aren't flowing — something's detached or frozen. Cliff (sudden drop) is usually a schedule kick or finally finding the right answer. Oscillation means lr is too high or batch size is too small. Step decay is normal — just confirms your scheduler is working. The arrow → means "implies"; recognising the pattern compresses hours of debugging into minutes.
divergence: loss shoots upward; usually lr or a numerical issue
plateau: flat from the start; gradient flow is broken
cliff: sudden drop; a schedule change or a breakthrough
oscillation: noisy ups and downs; lr too high or batch too small
step decay: discrete drops matching the LR schedule; usually fine
Loss curve forensics. Diverges sharply: lr too high, bad init, or numerical issue. Plateaus immediately: gradient flow stopped (detach, frozen). Oscillates: lr too high or batch too small. Cliff (sudden drop): a learning rate schedule kick, or the model finally found the right answer.
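For the "plateaus immediately" case, you can check directly whether gradients reach every parameter after one backward pass. The sketch below is a hypothetical helper with a deliberately buggy toy model. It separates frozen parameters, parameters cut off from the graph (grad is None) and parameters that receive an all-zero gradient:

```python
import torch
import torch.nn as nn

def gradient_flow_report(model: nn.Module):
    """Hypothetical helper: call after loss.backward().

    frozen    -> requires_grad=False
    no_grad   -> grad is None: cut off by .detach() / torch.no_grad, or unused
    zero_grad -> connected, but nothing flows (e.g. dead ReLUs, an all-zero mask)
    """
    frozen, no_grad, zero_grad = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            frozen.append(name)
        elif p.grad is None:
            no_grad.append(name)
        elif p.grad.abs().sum() == 0:
            zero_grad.append(name)
    return {"frozen": frozen, "no_grad": no_grad, "zero_grad": zero_grad}

# Toy model with a deliberate .detach() bug that cuts the encoder out of the graph
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(4, 4)
        self.head = nn.Linear(4, 1)

    def forward(self, x):
        return self.head(self.enc(x).detach())   # bug: gradients stop here

torch.manual_seed(0)
net = Net()
net(torch.randn(2, 4)).sum().backward()
print(gradient_flow_report(net))
# {'frozen': [], 'no_grad': ['enc.weight', 'enc.bias'], 'zero_grad': []}
```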
Distributed debugging. Two ranks producing different losses on the same data → a sync or seeding issue. Use torch.distributed.barrier() + per-rank logging to find where they diverge. Always reproduce on a single GPU before debugging distributed; you'd be amazed how often a 1-GPU debug fixes the cluster bug.
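OOM (out-of-memory) forensics. Memory usage growing over time → a leak: something isn't being freed. Most often you are keeping references to tensors that are still attached to the autograd graph (e.g. total_loss += loss instead of total_loss += loss.item()), which keeps every step's graph alive. Usage that plateaus high but stable → you just need a smaller batch or more aggressive activation checkpointing. torch.cuda.memory_summary() shows the allocation breakdown.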
Profiler-led debugging. If training is slow, profile first. Common culprits: data loading (a CPU bottleneck; increase num_workers), many small ops (kernel-launch overhead; fuse them), and synchronisation stalls (a collective such as torch.distributed.all_reduce or gather waits for the slowest rank).
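A first profile can be as short as the sketch below. It uses torch.profiler on a toy model and adds CUDA activity only if a GPU is present. The model, input sizes and iteration count are illustrative stand-ins:

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(),
                            torch.nn.Linear(512, 10))
x = torch.randn(64, 512)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    model, x = model.cuda(), x.cuda()

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(10):
        model(x).sum().backward()

# Aggregate per operator and show the ten most expensive by total CPU time
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```

Suppose the table is dominated by many tiny ops, each called thousands of times. That pattern points at launch overhead rather than one slow kernel.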
The "minimal reproducer" discipline. Reproduce any non-trivial bug in < 30 lines of code with explicit seeds and data. Most ML bugs become 10× easier once isolated; many disappear on re-creation.
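Explicit seeds are the first few lines of any reproducer. The helper below is a hypothetical sketch, not a library API. It seeds Python, NumPy and PyTorch and can optionally force deterministic kernels:

```python
import random

import numpy as np
import torch

def seed_everything(seed: int = 0, deterministic: bool = True) -> None:
    """Hypothetical helper: seed every RNG a typical training script touches."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                  # seeds the RNG on all devices, CUDA included
    if deterministic:
        # Raise on ops that lack a deterministic implementation instead of
        # silently varying. On CUDA some cuBLAS ops also need the environment
        # variable CUBLAS_WORKSPACE_CONFIG=:4096:8 set before startup.
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.benchmark = False
```

Call it once at the top of the reproducer, before building the model or the data. Two runs should then print identical numbers.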
Logging hygiene at scale. Per-rank log files. JSON-lines format for parsing. Trace IDs across services. Sentry / Datadog / similar for alerting on real production failures. Cardinality matters — don't log per-example fields you'll have a billion of.
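Per-rank JSON-lines files take only a few lines of standard library. This sketch assumes the launcher sets the RANK environment variable (torchrun does) and defaults to rank 0 otherwise. The class name and file layout are hypothetical:

```python
import json
import os
import time

class JsonlLogger:
    """Hypothetical logger: one JSON object per line, one file per rank."""

    def __init__(self, log_dir="logs"):
        self.rank = int(os.environ.get("RANK", 0))   # 0 for single-process runs
        os.makedirs(log_dir, exist_ok=True)
        self.path = os.path.join(log_dir, f"rank{self.rank}.jsonl")
        self._f = open(self.path, "a", buffering=1)  # line-buffered text file

    def log(self, event, **fields):
        record = {"time": time.time(), "rank": self.rank, "event": event, **fields}
        self._f.write(json.dumps(record) + "\n")

    def close(self):
        self._f.close()
```

Each line parses on its own, so the files can be loaded with pandas.read_json(path, lines=True). You can also diff two ranks step by step to find where they diverge.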
Reproducible bug reports. Save the config, the data hash, the git sha, and the exact CUDA / driver / torch version. Most "this used to work" bugs are environment drift.
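Most of that environment snapshot can be captured automatically. Here is a sketch with hypothetical helper names. It records the config, a SHA-256 of a data file, the git commit and the Python / torch / CUDA / cuDNN versions. torch does not expose the GPU driver version, so take that from nvidia-smi:

```python
import hashlib
import json
import platform
import subprocess
import sys

import torch

def git_sha():
    try:
        return subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True,
                              text=True, check=True).stdout.strip()
    except (OSError, subprocess.CalledProcessError):
        return None                          # git missing, or not a git checkout

def file_sha256(path, chunk=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk), b""):
            h.update(block)
    return h.hexdigest()

def bug_report(config, data_path=None):
    """Collect what someone else needs to reproduce this run."""
    return {
        "config": config,
        "data_sha256": file_sha256(data_path) if data_path else None,
        "git_sha": git_sha(),
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda": torch.version.cuda,                    # None on CPU-only builds
        "cudnn": torch.backends.cudnn.version(),       # None if cuDNN is unavailable
        "gpus": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
    }

print(json.dumps(bug_report({"lr": 3e-4, "batch_size": 32}), indent=2))
```

Dumping this as JSON next to the bug report keeps it diff-friendly. Comparing two such files is often the quickest way to spot environment drift.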
import gc
import torch

def memory_audit(label=""):
    gc.collect()                          # free unreachable Python objects first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()          # release cached, unused blocks so "reserved" reflects live use
        alloc = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"[{label}] alloc {alloc:.2f} GB  reserved {reserved:.2f} GB")

# Find the biggest tensors alive: useful for leak hunting
def biggest_tensors(top_k=10):
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append((obj.numel() * obj.element_size(), tuple(obj.shape),
                                obj.dtype, obj.device))
        except Exception:                 # some tracked objects raise on inspection; skip them
            pass
    tensors.sort(key=lambda t: t[0], reverse=True)   # sort by size only; dtypes aren't orderable
    for size, shape, dtype, device in tensors[:top_k]:
        print(f"  {size / 1e6:6.1f} MB  {shape}  {dtype}  {device}")