
Key idea

The cluster gives you many GPUs; the framework lets them cooperate. Data parallel: same model on each GPU, different data, gradients all-reduced. Model parallel: split the model itself across GPUs. FSDP / ZeRO: data parallel, but with parameters, gradients, and optimizer state sharded across GPUs. The right choice depends on what fits in GPU memory, and "what fits" is the main constraint.

Data parallel (DDP). The default. Each GPU has a full copy of the model. They process different batches, then all-reduce gradients. Easy to set up, scales well for models that fit on one GPU.
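Why averaging gradients is enough: for a mean-reduced loss and equally sized per-rank batches, the average of the per-rank gradients equals the gradient over the combined batch. A minimal single-process sketch, assuming a hypothetical 1-D linear model with squared-error loss; there are no real GPUs or collectives here, and the "all-reduce" is just a Python average.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one batch."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

def ddp_step_gradient(w, xs, ys, world_size):
    """Split the batch into equal per-rank shards, then average their gradients
    (what DDP's all-reduce computes; simulated here in one process)."""
    assert len(xs) % world_size == 0, "equal shard sizes assumed"
    shard = len(xs) // world_size
    local_grads = [
        grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    return sum(local_grads) / world_size   # all-reduce(SUM) / world_size

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5
full = grad_mse(w, xs, ys)
dp = ddp_step_gradient(w, xs, ys, world_size=4)
print(full, dp)  # equal up to floating-point rounding
```

If shard sizes differ (e.g. an uneven last batch), the plain average no longer matches the full-batch gradient exactly.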

FSDP (Fully Sharded Data Parallel). Shards parameters, gradients, and optimizer state across GPUs. Lets you train models that wouldn't fit on a single GPU. Built into PyTorch; DeepSpeed's ZeRO has comparable functionality.

Tensor / model parallel. Split individual layers across GPUs. Used for very large models where even FSDP is insufficient. Megatron, FairScale. Most teams don't need this.

Pipeline parallel. Run different layers on different GPUs, pipelined. GPipe, PipeDream. Useful when the whole model doesn't fit on one GPU but a single layer does.

Pick by model size

  • Fits on one GPU and training is fast enough: don't bother with distributed; a single GPU is simplest
  • Model fits, want more throughput: DDP
  • Model nearly fits: FSDP / ZeRO-2
  • Model too big for any GPU: FSDP + activation checkpointing + sometimes tensor parallel
  • Hundreds of GPUs: 2D / 3D parallelism (e.g. data × tensor × pipeline)

Common pitfalls

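  • Inconsistent seeds across ranks → replicas or data shards silently disagree (e.g. a different DistributedSampler seed per rank makes shards overlap)
  • Logging from all ranks → step counter inflated by world_size
  • Plain BatchNorm computes statistics per rank, not globally → noisy stats with small per-GPU batches; convert to SyncBatchNorm if that matters
  • Saving from every rank to the same path → race conditions, corrupted checkpoints; save from rank 0 only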
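The seed pitfall above can be reproduced without GPUs. This is a simplified model of DistributedSampler's sharding (shuffle with seed + epoch, pad to a multiple of the world size, then take every world_size-th index); it uses Python's random module instead of torch.randperm, so the exact permutations differ from PyTorch's. With one shared seed the shards are disjoint and cover the dataset; with a different seed per rank, some samples are seen twice and others never.

```python
import random

def shard_indices(n, rank, world_size, seed, epoch):
    """Simplified DistributedSampler-style sharding (not the real implementation)."""
    rng = random.Random(seed + epoch)          # every rank must use the same seed
    idx = list(range(n))
    rng.shuffle(idx)
    total = -(-n // world_size) * world_size   # round n up to a multiple of world_size
    idx += idx[: total - n]                    # pad by repeating the first indices
    return idx[rank:total:world_size]          # rank r takes every world_size-th index

def coverage(shards):
    """Number of distinct samples seen across all ranks."""
    return len(set().union(*shards))

n, world = 1000, 4
same = [shard_indices(n, r, world, seed=0, epoch=0) for r in range(world)]
diff = [shard_indices(n, r, world, seed=r, epoch=0) for r in range(world)]  # the bug

print(coverage(same), coverage(diff))  # 1000 vs. fewer than 1000
```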

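import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# MyNet, dataset, loss_fn and num_epochs are placeholders for your own code.

def setup_ddp():
    dist.init_process_group("nccl")             # NCCL for NVIDIA GPUs, gloo for CPU
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun: GPU index on this node
    torch.cuda.set_device(local_rank)           # global rank != GPU index on multi-node jobs
    return local_rank

def train(local_rank):
    model = DDP(MyNet().cuda(local_rank), device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)   # example optimizer and lr
    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(), shuffle=True)  # disjoint shard per rank
    loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)  # 64 per GPU

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)                # new shuffle each epoch, identical on all ranks
        for x, y in loader:
            loss = loss_fn(model(x.cuda(local_rank)), y.cuda(local_rank))
            opt.zero_grad()
            loss.backward()                     # DDP all-reduces (averages) gradients here
            opt.step()

        # Save only on global rank 0; model.module is the unwrapped model
        if dist.get_rank() == 0:
            torch.save(model.module.state_dict(), f"epoch_{epoch}.pt")
        dist.barrier()                          # other ranks wait until the save is done

if __name__ == "__main__":
    train(setup_ddp())
    dist.destroy_process_group()

# Launch: torchrun --nproc-per-node=4 train.py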

FSDP memory savings

$$ \text{memory per GPU} \approx \frac{P + G + O}{W} + A $$

  • P parameters, G gradients, O optimizer state, W world size, A activations
  • Vanilla DDP: P + G + O + A per GPU
  • FSDP: divides the first three by W, plus a transient full copy of the one block whose parameters are currently all-gathered

$$ \text{memory per GPU} \;\approx\; \frac{\text{params} + \text{gradients} + \text{optimizer state}}{\text{number of GPUs}} \;+\; \text{activations} $$

In words. A back-of-the-envelope for how much GPU memory FSDP saves. There are four big consumers: the model parameters P, their gradients G, the optimizer state O (for Adam, this is roughly 2× the parameter count), and the activations A kept for backprop. Vanilla DDP keeps full copies of all four on every GPU. FSDP shards the first three across W GPUs (the "world size"), so each GPU only holds a fraction. Activations are not sharded by FSDP — they're per-sample, so each GPU still has its own based on its mini-batch.

  • P (params): the model weights
  • G (gradients): same shape as params, one per backward pass
  • O (optimizer state): Adam's momentum + variance (~2× params)
  • W (number of GPUs): total GPUs across the cluster ("world size")
  • A (activations): intermediate tensors kept for backward; not sharded
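Plugging numbers into the formula makes the saving concrete. A back-of-the-envelope sketch, assuming a hypothetical 7B-parameter model on 8 GPUs with fp32 parameters and gradients plus Adam's two fp32 states (4 + 4 + 8 = 16 bytes per parameter); the byte counts are assumptions to adjust for your precision setup, and activations default to zero here.

```python
def per_gpu_model_memory_gb(n_params, world_size, sharded,
                            bytes_param=4, bytes_grad=4, bytes_optim=8,
                            activations_gb=0.0):
    """Back-of-the-envelope P + G + O (+ A) per GPU, in GB (1e9 bytes)."""
    P = n_params * bytes_param / 1e9
    G = n_params * bytes_grad / 1e9
    O = n_params * bytes_optim / 1e9     # Adam: momentum + variance in fp32
    model_state = P + G + O
    if sharded:                          # FSDP / ZeRO-3: each rank holds 1/W of P, G, O
        model_state /= world_size
    return model_state + activations_gb  # activations are not sharded

ddp = per_gpu_model_memory_gb(7e9, world_size=8, sharded=False)
fsdp = per_gpu_model_memory_gb(7e9, world_size=8, sharded=True)
print(ddp, fsdp)  # 112.0 14.0
```

The 112 GB replicated state does not fit on an 80 GB GPU, while the 14 GB shard leaves room for activations and the transient gathered block.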

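FSDP details. FSDP wraps your model recursively at a chosen granularity (per layer or per block). In the forward pass it all-gathers the parameters of the active block, runs it, then frees the gathered copy. Backward all-gathers them again, then reduce-scatters the gradients so each rank keeps only the gradient shard it owns. The extra communication is the price of the memory savings.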
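A pure-Python simulation of one FSDP unit, assuming W hypothetical ranks living in a single process with lists standing in for tensors (no real collectives): each rank stores only its shard, the forward pass all-gathers the full parameters, and backward reduce-scatters the gradients so each rank keeps the averaged gradient for its own shard. The helper names are illustrative, not PyTorch APIs.

```python
def shard(flat, world_size):
    """Split a flat parameter list into equal per-rank shards (no padding needed here)."""
    assert len(flat) % world_size == 0
    k = len(flat) // world_size
    return [flat[r * k:(r + 1) * k] for r in range(world_size)]

def all_gather(shards):
    """Every rank ends up holding the full flat list."""
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]

def reduce_scatter_mean(per_rank_full_grads):
    """Average full-size gradients across ranks, then give each rank only its shard."""
    world_size = len(per_rank_full_grads)
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_full_grads)]
    return shard(averaged, world_size)

W = 2
params = [1.0, 2.0, 3.0, 4.0]
param_shards = shard(params, W)        # what each rank stores between uses
gathered = all_gather(param_shards)    # forward: materialise the block's params...
# ...run the block, then free `gathered`; backward gathers again the same way.
local_grads = [[0.1, 0.2, 0.3, 0.4], [0.3, 0.2, 0.1, 0.0]]  # hypothetical per-rank grads
grad_shards = reduce_scatter_mean(local_grads)
print(param_shards, grad_shards)
```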

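Gradient bucketing. Rather than all-reducing each gradient tensor separately, DDP packs gradients into "buckets" and issues one all-reduce per bucket, which cuts per-call overhead. The default bucket size (DDP's bucket_cap_mb argument, 25 MB) is usually fine; tune it for very small or very large models.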
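A simplified sketch of greedy bucketing under a size cap. It walks gradients in reverse registration order, since the PyTorch DDP notes say parameters are assigned to buckets roughly in reverse order of model.parameters() (about the order backward produces them); the real implementation has extra details (e.g. a smaller first bucket) that this ignores. The gradient sizes are hypothetical.

```python
def build_buckets(grad_sizes_bytes, cap_bytes):
    """Greedy bucketing in reverse registration order.
    Returns lists of gradient indices; one all-reduce would be issued per bucket."""
    buckets, current, current_bytes = [], [], 0
    for i in reversed(range(len(grad_sizes_bytes))):
        size = grad_sizes_bytes[i]
        if current and current_bytes + size > cap_bytes:  # cap would be exceeded
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(i)            # an oversized gradient still gets its own bucket
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets

MB = 1024 * 1024
sizes = [4 * MB, 1 * MB, 10 * MB, 12 * MB, 30 * MB, 2 * MB]  # hypothetical per-parameter grads
print(build_buckets(sizes, cap_bytes=25 * MB))  # [[5], [4], [3, 2, 1], [0]]
```

Six gradients become four all-reduces; a larger cap means fewer, bigger calls but a later start for the first one.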

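Overlap compute and communication. During backward, DDP and FSDP start communicating a bucket's gradients as soon as they are ready, while gradients for earlier layers are still being computed. The defaults are usually good; verify with the profiler that communication kernels overlap with compute and that the gaps between kernels are small.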
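A toy timeline model of why overlap helps, assuming hypothetical per-bucket compute and communication times and a single communication channel. It is not how the profiler measures anything; it only shows that with overlap, most communication hides behind compute and roughly only the last bucket's all-reduce stays exposed.

```python
def backward_time(compute, comm, overlap):
    """compute[i] / comm[i]: time to compute / all-reduce bucket i, in backward order.
    Without overlap every all-reduce waits for the whole backward; with overlap,
    bucket i's all-reduce starts once its grads exist and the channel is free."""
    if not overlap:
        return sum(compute) + sum(comm)
    t_compute, t_comm = 0.0, 0.0
    for c, m in zip(compute, comm):
        t_compute += c                        # this bucket's gradients are ready
        t_comm = max(t_comm, t_compute) + m   # start when ready and channel idle
    return max(t_compute, t_comm)

compute = [3.0, 3.0, 3.0, 3.0]   # hypothetical ms per bucket
comm = [2.0, 2.0, 2.0, 2.0]
print(backward_time(compute, comm, overlap=False), backward_time(compute, comm, overlap=True))  # 20.0 14.0
```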

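Checkpoint sharding. Saving a 100B-parameter model as a single 400 GB file (fp32 weights alone) is impractical. With FSDP, StateDictType.SHARDED_STATE_DICT makes each rank produce only its own shard, so each rank writes one file. Reload with the same world size and wrapping, or use torch.distributed.checkpoint, which can reshard on load.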
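The 400 GB figure is just parameters × bytes. A small sketch, assuming fp32 weights only (Adam's two fp32 states would add 8 bytes per parameter, roughly tripling the total) and a hypothetical 64-rank even split:

```python
def checkpoint_gb(n_params, bytes_per_param, world_size=1):
    """Weights-only checkpoint size in GB: (total, per rank when sharded evenly)."""
    total = n_params * bytes_per_param / 1e9
    return total, total / world_size

print(checkpoint_gb(100e9, 4))                 # (400.0, 400.0): one fp32 file
print(checkpoint_gb(100e9, 4, world_size=64))  # (400.0, 6.25): one shard per rank
```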

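Mixed precision in distributed. Use bf16 for compute and fp32 for the master parameters. bf16 has the same exponent range as fp32, so it usually needs no loss scaling; if you use fp16 with FSDP, use ShardedGradScaler (the FSDP-aware GradScaler). Most frameworks handle this automatically when you turn on AMP + the right FSDP precision policy.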
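Why fp16 needs loss scaling and bf16 usually doesn't: a tiny gradient underflows to zero in fp16 but survives in bf16. A sketch using only the standard library, assuming a hypothetical gradient value of 1e-8 and a fixed scale of 2^16 (GradScaler picks its scale dynamically); to_bf16 is a hand-written round-to-nearest-even that ignores NaN and infinity.

```python
import struct

def to_fp16(x):
    """Round a Python float to IEEE half precision and back (struct's 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    """Round to bfloat16: keep the top 16 bits of fp32, round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

grad = 1e-8                            # hypothetical tiny gradient value
print(to_fp16(grad))                   # 0.0: below fp16's smallest subnormal
print(to_bf16(grad))                   # close to 1e-8: bf16 keeps fp32's exponent range
scale = 2.0 ** 16                      # loss scaling
print(to_fp16(grad * scale) / scale)   # non-zero again after unscaling
```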

The launcher. torchrun for PyTorch, accelerate launch from HuggingFace, deepspeed for DeepSpeed. Each handles process spawning, environment variables, and (sometimes) restart-from-failure.

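import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, StateDictType
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# TransformerBlock and MyNet are placeholders for your model's classes.

def fsdp_wrap(model):
    return FSDP(
        model,
        # Each TransformerBlock becomes its own FSDP unit (gathered and freed separately)
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}
        ),
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,    # compute in bf16
            reduce_dtype=torch.bfloat16,   # gradient reduce-scatter in bf16 (fp32: more precise, more traffic)
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # gather the next unit's params early in backward
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )

# Save a sharded checkpoint: each rank writes only its own shard.
# Assumes dist.init_process_group() has run and model = fsdp_wrap(MyNet()).
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sd = model.state_dict()
    torch.save(sd, f"ckpt-rank{dist.get_rank()}.pt")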

3D parallelism

$$ \text{world} = \text{tensor} \times \text{pipeline} \times \text{data} $$

  • Three orthogonal axes of parallelism
  • Total GPUs = product of the three
  • The standard recipe at LLM-training scale

$$ \text{total GPUs} \;=\; (\text{tensor-parallel size}) \;\times\; (\text{pipeline-parallel size}) \;\times\; (\text{data-parallel size}) $$

In words. At LLM-training scale, you split the work along three orthogonal axes simultaneously. Tensor parallel shards individual layers across a few GPUs in a node (NVLink). Pipeline parallel puts different layers on different nodes, streaming micro-batches through. Data parallel replicates the whole stack on multiple groups, with different data per replica. The cluster's total GPU count factorises into the product of the three group sizes — and the right factorisation depends on the model's shape, the interconnect topology, and what fits where.

  • tensor-parallel size: GPUs cooperating on a single layer's weights
  • pipeline-parallel size: stages of the pipeline (each holds some layers)
  • data-parallel size: independent replicas processing different batches
  • total GPUs: product of all three, your full cluster
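Two small helpers make the factorisation concrete. The first enumerates every (tensor, pipeline, data) split of a world size, keeping tensor-parallel groups inside one node so they can use NVLink. The second maps a global rank to its coordinates, assuming tensor-parallel ranks vary fastest, then data-parallel, then pipeline stage; that ordering is an illustrative assumption, and frameworks such as Megatron-LM define their own.

```python
def factorisations(world_size, gpus_per_node):
    """All (tp, pp, dp) with tp * pp * dp == world_size and tp dividing gpus_per_node."""
    out = []
    for tp in range(1, gpus_per_node + 1):
        if gpus_per_node % tp or world_size % tp:
            continue                      # TP group must fit evenly in a node
        rest = world_size // tp
        for pp in range(1, rest + 1):
            if rest % pp == 0:
                out.append((tp, pp, rest // pp))
    return out

def rank_to_coords(rank, tp, pp, dp):
    """Global rank -> (tp_rank, dp_rank, pp_rank), with TP ranks consecutive
    (so a TP group sits on neighbouring GPUs of one node)."""
    assert 0 <= rank < tp * pp * dp
    return rank % tp, (rank // tp) % dp, rank // (tp * dp)

print(len(factorisations(16, gpus_per_node=8)))  # 14 candidate layouts
print(rank_to_coords(5, tp=2, pp=2, dp=4))       # (1, 2, 0)
```

Choosing among the candidates is the hard part: it depends on model shape, interconnect, and memory, which this sketch does not model.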

Tensor parallel. Split a single layer (e.g., a linear's weight matrix) across GPUs. Each GPU computes part of the output. Requires fast intra-node interconnect (NVLink). Used for the largest models.
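A single-process check of the two ways a linear layer's weight can be split, with plain Python lists as hypothetical stand-ins for tensors and GPUs. Splitting W by output columns needs an all-gather (concatenation) of the partial outputs; splitting W by input rows, with the matching slice of x on each GPU, needs an all-reduce (sum). Megatron-LM pairs a column-parallel layer with a row-parallel one in its MLP so the pair needs only one all-reduce in the forward pass.

```python
def matmul(x, w):
    """x: [n][k], w: [k][m] -> [n][m]."""
    return [[sum(x[i][t] * w[t][j] for t in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]

def split_cols(a, parts):
    m = len(a[0]) // parts
    return [[row[p * m:(p + 1) * m] for row in a] for p in range(parts)]

def split_rows(a, parts):
    k = len(a) // parts
    return [a[p * k:(p + 1) * k] for p in range(parts)]

x = [[1.0, 2.0], [3.0, 4.0]]                       # batch of 2, in_features = 2
w = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]   # in 2 -> out 4
full = matmul(x, w)

# Column parallel: each "GPU" owns half the output columns; concatenate (all-gather).
col = [matmul(x, wp) for wp in split_cols(w, 2)]
col_out = [r0 + r1 for r0, r1 in zip(*col)]

# Row parallel: each "GPU" owns half of W's input rows and the matching features of x;
# partial outputs are summed (all-reduce).
row = [matmul(xp, wp) for xp, wp in zip(split_cols(x, 2), split_rows(w, 2))]
row_out = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(*row)]

print(full == col_out == row_out)  # True
```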

Pipeline parallel. Different layers on different GPUs. Micro-batches flow through the pipeline; "bubble" of idle time at start and end. GPipe (Huang et al. 2018), PipeDream, 1F1B (one-forward-one-backward) scheduling.
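A toy schedule showing where the bubble comes from, assuming every stage takes the same time per micro-batch and looking only at the forward pass (GPipe's backward mirrors it). With p stages and m micro-batches the pipeline needs m + p - 1 time slots, each stage is busy for m of them, so the idle fraction is (p - 1)/(m + p - 1).

```python
from fractions import Fraction

def gpipe_forward_schedule(stages, micro_batches):
    """schedule[t][s] = micro-batch stage s works on in slot t (None = idle).
    Micro-batch j reaches stage s at slot s + j."""
    slots = micro_batches + stages - 1
    schedule = [[None] * stages for _ in range(slots)]
    for s in range(stages):
        for j in range(micro_batches):
            schedule[s + j][s] = j
    return schedule

def bubble_fraction(stages, micro_batches):
    """Fraction of stage-slots spent idle."""
    sched = gpipe_forward_schedule(stages, micro_batches)
    idle = sum(cell is None for row in sched for cell in row)
    return Fraction(idle, len(sched) * stages)

print(bubble_fraction(4, 8))  # 3/11
```

More micro-batches shrink the bubble. The non-interleaved 1F1B schedule has the same bubble size but keeps fewer micro-batches' activations alive at once.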

3D parallelism. Combine DP, TP, PP. Standard at the largest scales (1000+ GPUs). Megatron-LM is the reference implementation. NVIDIA's NeMo, DeepSpeed, and PyTorch's torch.distributed all support some version.

NCCL tuning. Set NCCL_DEBUG=INFO for diagnostics. NCCL is topology-aware: it detects PCIe and NVLink links on its own. For very large jobs you can tune NCCL_TREE_THRESHOLD, NCCL_ALGO, and NCCL_PROTO, but this is mostly trial and error; profile before tuning.

Multi-node debugging. When jobs run for hours on hundreds of nodes, a single bad NIC can corrupt training. Heartbeat checks, gradient norm checks, periodic checkpoints. Sentry / Datadog / Grafana with cluster-level alerting.
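One cheap form of gradient-norm check: flag non-finite norms and spikes far above the recent median. A sketch with hypothetical class name and thresholds (window, spike_factor, and the 10-step warm-up are assumptions to tune); in a real loop the value could come from torch.nn.utils.clip_grad_norm_, which returns the total norm.

```python
import math
from collections import deque
from statistics import median

class GradNormMonitor:
    """Flag non-finite gradient norms or spikes far above the recent median."""
    def __init__(self, window=100, spike_factor=10.0):
        self.history = deque(maxlen=window)   # recent "ok" norms only
        self.spike_factor = spike_factor

    def check(self, grad_norm):
        if not math.isfinite(grad_norm):
            return "non-finite"
        if len(self.history) >= 10 and grad_norm > self.spike_factor * median(self.history):
            return "spike"                    # not added, so it can't mask the next spike
        self.history.append(grad_norm)
        return "ok"

mon = GradNormMonitor()
for step in range(20):
    mon.check(1.0 + 0.01 * step)              # steady training
print(mon.check(50.0), mon.check(float("nan")))  # spike non-finite
```

A flagged step is a prompt to look at the node that produced it, or to roll back to the last checkpoint, rather than proof of hardware failure.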

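Elastic training. Nodes can join or drop mid-training; torchrun then restarts the workers with the new membership. Use torchrun --nnodes=MIN:MAX with a rendezvous backend (e.g. --rdzv-backend=c10d, or etcd) for elastic rendezvous. Checkpointing must be fast (incremental, streaming) to make resumes cheap.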

The cluster you wish you had. Saturated all-reduce bandwidth is rare on real clusters; profile to confirm. Often the bottleneck is something more mundane: a slow storage backend, a bad scheduler queue, or a single bad node.

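# Megatron-style tensor parallelism for a Linear layer (one TP group = the whole world)
import torch, torch.nn as nn
import torch.distributed as dist

class _GatherLastDim(torch.autograd.Function):
    """All-gather along the last dim; backward keeps only this rank's slice."""
    @staticmethod
    def forward(ctx, x):
        parts = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(parts, x.contiguous())
        return torch.cat(parts, dim=-1)

    @staticmethod
    def backward(ctx, grad_out):
        # Downstream ops run identically on every rank, so each rank already holds
        # the full output gradient; keep the slice for its own output columns.
        return grad_out.chunk(dist.get_world_size(), dim=-1)[dist.get_rank()].contiguous()

class ColumnParallelLinear(nn.Module):
    """Splits the weight matrix's output dimension across ranks."""
    def __init__(self, in_features, out_features, gather_output=True):
        super().__init__()
        world_size = dist.get_world_size()
        assert out_features % world_size == 0
        # Each rank holds out_features / world_size of the output features
        self.linear = nn.Linear(in_features, out_features // world_size, bias=False)
        self.gather_output = gather_output

    def forward(self, x):
        local_out = self.linear(x)        # this rank's slice of the output
        if not self.gather_output:        # e.g. when a row-parallel layer follows
            return local_out
        # Plain dist.all_gather is not autograd-aware; the Function above is
        return _GatherLastDim.apply(local_out)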