Training a large language model is one of the most resource-intensive computational tasks in modern AI. A single GPT-4 scale training run can cost tens of millions of dollars and consume megawatt-hours of electricity over several months. At this scale, every percentage point of efficiency matters, and the difference between a stable and unstable training run can mean millions of dollars saved or wasted.

In this post, we examine the key techniques that make modern LLM training feasible: optimizers designed for the scale of transformer parameters, learning rate schedules that guide convergence, mixed precision arithmetic that doubles throughput without sacrificing model quality, gradient checkpointing to fit larger models in memory, gradient clipping for stability, and distributed training strategies that coordinate hundreds or thousands of GPUs.

The Training Loop

At its core, LLM training follows the standard supervised learning loop: forward pass to compute the loss, backward pass to compute gradients, and an optimizer step to update parameters. But the details at transformer scale are anything but standard.

The canonical training step updates parameters \theta according to:

\theta_{t+1} = \theta_t - \eta_t \cdot \text{Update}(\nabla_\theta \mathcal{L}, \theta_t, t)

where \eta_t is a time-varying learning rate (governed by a schedule), \nabla_\theta \mathcal{L} is the gradient of the language modeling loss (typically cross-entropy over the vocabulary), and \text{Update}(\cdot) is the optimizer-specific transformation of the raw gradient.

For autoregressive language models, the loss on a sequence x_1, x_2, \ldots, x_T is the average negative log-likelihood of next-token prediction:

\mathcal{L}(\theta) = -\frac{1}{T} \sum_{t=1}^{T} \log P_\theta(x_t \mid x_1, \ldots, x_{t-1})

A modern training step involves far more than this equation suggests. The forward pass runs in reduced precision (BF16), gradients are accumulated across micro-batches, clipped to prevent explosions, then synchronized across hundreds of GPUs before the optimizer applies its update in FP32.

import torch
from torch.cuda.amp import autocast

def training_step(model, batch, optimizer, grad_accum_steps, max_grad_norm):
    """A single training step with BF16 mixed precision, gradient accumulation, and clipping.

    BF16 has the same dynamic range as FP32, so no GradScaler is needed (unlike FP16).
    """
    optimizer.zero_grad()
    total_loss = 0.0

    for micro_step in range(grad_accum_steps):
        micro_batch = batch[micro_step]
        with autocast(dtype=torch.bfloat16):
            logits = model(micro_batch["input_ids"])
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                micro_batch["labels"].view(-1),
                ignore_index=-100,
            )
            loss = loss / grad_accum_steps  # Normalize by accumulation steps

        loss.backward()
        total_loss += loss.item()

    # Gradient clipping before optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()

    return total_loss

Optimizers: From SGD to AdamW

Vanilla SGD and Its Limitations

Stochastic Gradient Descent updates parameters by subtracting the gradient scaled by a learning rate: \theta_{t+1} = \theta_t - \eta \cdot g_t. This works well for convex problems but struggles with the highly non-convex loss landscapes of transformers. The gradient can oscillate wildly across different parameter groups --- embedding matrices may have gradients many orders of magnitude larger than attention weight matrices --- making a single global learning rate inadequate.

Adam: Adaptive Moment Estimation

Adam (Kingma & Ba, 2015) addresses this by maintaining per-parameter running estimates of the first moment (mean) and second moment (uncentered variance) of the gradient:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \quad \text{(first moment estimate)}

v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \quad \text{(second moment estimate)}

Because m_t and v_t are initialized at zero, they are biased toward zero during early training. The bias-corrected estimates are:

\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}

The update divides the first moment by the square root of the second moment, effectively giving each parameter its own adaptive learning rate:

\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

Standard values are \beta_1 = 0.9, \beta_2 = 0.999, \epsilon = 10^{-8}. For LLM training, \beta_2 = 0.95 is increasingly common (as used in LLaMA and GPT-3), which makes the variance estimate more responsive to recent gradients and can improve stability in later training stages.
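To make the bias correction concrete, consider the very first update for a single parameter. The raw moments are pulled strongly toward their zero initialization, and the correction factors 1 - \beta^t exactly undo this at t = 1. A minimal plain-Python sketch (illustrative values only):

```python
beta1, beta2, eps = 0.9, 0.999, 1e-8
m, v = 0.0, 0.0   # moment buffers start at zero
g = 2.0           # first observed gradient

t = 1
m = beta1 * m + (1 - beta1) * g       # 0.2: biased far below the true gradient
v = beta2 * v + (1 - beta2) * g ** 2  # 0.004: likewise biased toward zero

# Bias correction divides by 1 - beta^t, recovering g and g^2 exactly at t = 1
m_hat = m / (1 - beta1 ** t)  # 2.0
v_hat = v / (1 - beta2 ** t)  # 4.0

# Update direction: m_hat / (sqrt(v_hat) + eps) ~= g / |g|, so at t = 1 the
# step has unit magnitude regardless of the gradient's scale
update = m_hat / (v_hat ** 0.5 + eps)
```

Note that both moments are corrected; correcting only one of them would distort the scale of the very first updates.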

AdamW: Why L2 Regularization Fails in Adam

Loshchilov & Hutter (2019) identified a subtle but critical flaw in how Adam handles weight decay. To understand it, consider how L2 regularization is typically implemented: the loss becomes \mathcal{L}' = \mathcal{L} + \frac{\lambda}{2}\|\theta\|^2, so the gradient becomes g_t' = g_t + \lambda \theta_t.

In vanilla SGD, applying L2 regularization through the gradient is mathematically equivalent to applying weight decay directly. The SGD update with L2 is:

\theta_{t+1} = \theta_t - \eta (g_t + \lambda \theta_t) = (1 - \eta\lambda)\theta_t - \eta g_t

This is identical to subtracting \eta\lambda\theta_t directly from the parameters (weight decay). However, in Adam, the regularization gradient \lambda\theta_t gets divided by \sqrt{\hat{v}_t} + \epsilon:

\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t + \lambda\theta_t}{\sqrt{\hat{v}_t} + \epsilon}

This means parameters with large historical gradients (large \hat{v}_t) receive less effective regularization, while parameters with small gradients receive more. This is the opposite of what we want --- large, active parameters should arguably be regularized more strongly. The adaptive scaling that makes Adam effective for optimization actively undermines the regularization.
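A toy numeric sketch of this coupling (illustrative values, not a full optimizer): two parameters with the same weight decay \lambda but different gradient histories receive very different effective decay under Adam-style L2, while decoupled decay treats them identically.

```python
import math

lr, lam = 1e-3, 0.1
theta = 1.0  # same parameter value in every case

# Two parameters with very different second-moment estimates
v_hat_large = 100.0  # historically large gradients
v_hat_small = 1e-4   # historically tiny gradients

def l2_decay_step(v_hat):
    # L2-through-the-gradient: lambda*theta passes through the adaptive denominator
    return lr * (lam * theta) / (math.sqrt(v_hat) + 1e-8)

def decoupled_decay_step():
    # AdamW: the decay term bypasses the denominator entirely
    return lr * lam * theta

print(l2_decay_step(v_hat_large))  # ~1e-5: weak decay for the "active" parameter
print(l2_decay_step(v_hat_small))  # ~1e-2: 1000x stronger for the "quiet" one
print(decoupled_decay_step())      # 1e-4: identical for both under AdamW
```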

AdamW fixes this by applying weight decay directly to the parameters, completely bypassing the adaptive scaling:

\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\theta_t \right)

Note the key difference: weight decay (\lambda\theta_t) is added after the adaptive division, not before. Every parameter now receives the same proportional decay regardless of its gradient history.

class AdamW:
    """AdamW optimizer with decoupled weight decay.

    The key difference from Adam: weight decay is applied directly to
    parameters, not through the gradient (which would be scaled by the
    adaptive learning rate).
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.95),
                 eps=1e-8, weight_decay=0.1):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.t = 0

        # Initialize moment buffers
        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]

    def step(self):
        self.t += 1
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue

            grad = param.grad.data

            # Update biased first and second moment estimates
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * grad ** 2

            # Bias correction
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)

            # AdamW update: weight decay is DECOUPLED from the adaptive update
            param.data -= self.lr * (
                m_hat / (v_hat.sqrt() + self.eps)  # Adaptive gradient step
                + self.weight_decay * param.data    # Direct weight decay
            )

Optimizer Memory Overhead

Every optimizer state variable costs memory proportional to model size. For a 70B parameter model in FP32, each extra state tensor adds 280 GB:

| Optimizer | States per Parameter | Extra Memory per Param | Notes |
|---|---|---|---|
| SGD | 0 | 0 bytes | Simplest, but requires careful LR tuning |
| SGD + Momentum | 1 (momentum buffer) | 4 bytes | More stable convergence |
| Adam / AdamW | 2 (m_t, v_t) | 8 bytes | Adaptive LR, standard for LLMs |
| Adafactor | ~1 (factored v_t) | ~4 bytes | Factorizes second moment as outer product |
| Lion | 1 (momentum buffer) | 4 bytes | Sign-based update, competitive results |
| 8-bit Adam | 2 (quantized) | 2 bytes | Quantized optimizer states |

For a 70B model, AdamW requires ~560 GB of optimizer state alone (two FP32 tensors), which is a primary motivation for techniques like FSDP and 8-bit optimizers.
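These per-parameter costs are easy to sanity-check with back-of-the-envelope arithmetic. A small helper (real frameworks add padding, buffers, and fragmentation on top):

```python
def optimizer_state_gb(n_params, states_per_param, bytes_per_state=4):
    """Optimizer state memory in GB (FP32 states by default)."""
    return n_params * states_per_param * bytes_per_state / 1e9

n = 70e9  # 70B parameters

adamw = optimizer_state_gb(n, states_per_param=2)         # m_t and v_t in FP32
momentum_sgd = optimizer_state_gb(n, states_per_param=1)  # one FP32 buffer
adam_8bit = optimizer_state_gb(n, states_per_param=2, bytes_per_state=1)

print(f"AdamW:      {adamw:.0f} GB")         # 560 GB
print(f"SGD+mom:    {momentum_sgd:.0f} GB")  # 280 GB
print(f"8-bit Adam: {adam_8bit:.0f} GB")     # 140 GB
```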

Learning Rate Schedules

The learning rate is arguably the single most important hyperparameter in LLM training. Modern practice universally uses schedules that combine warmup with a decay phase.

Linear Warmup

Warmup is essential for transformer training stability. During the first W steps, the learning rate increases linearly from near-zero to the peak value:

\eta(t) = \eta_{\max} \cdot \frac{t}{W}, \quad t < W

Why is warmup necessary? Three factors converge:

  1. Adam's variance estimate is unreliable early on. The second moment v_t is initialized to zero and takes many steps to reflect the true gradient variance. With a large learning rate, the denominator \sqrt{\hat{v}_t} is artificially small, causing extremely large updates that can destabilize training or push the model into a bad loss basin.

  2. Layer normalization gradients are large initially. Before the model learns meaningful representations, gradients through layer normalization can be very noisy. Warmup gives the normalization statistics time to stabilize.

  3. Embedding matrices see sparse, high-magnitude gradients. Only a small subset of tokens appears in each batch, but those token embeddings receive concentrated gradient updates. Warmup prevents these few vectors from being pushed too far before the model has a chance to learn distributional patterns.

Typical warmup durations are 0.1-2% of total training steps. LLaMA uses 2000 steps; GPT-3 uses warmup over the first 375 million tokens.

Cosine Decay

After warmup, the learning rate follows a cosine curve from \eta_{\max} down to \eta_{\min}:

\eta(t) = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{\pi (t - W)}{T - W}\right)\right)

where W is the warmup period and T is the total training duration. Cosine decay provides a smooth, gradual reduction that spends more time at moderate learning rates than linear decay. Most LLMs set \eta_{\min} = 0.1 \cdot \eta_{\max}.

Linear Decay

A simpler alternative that decreases the learning rate at a constant rate:

\eta(t) = \eta_{\max} \cdot \left(1 - \frac{t - W}{T - W}\right)

Linear decay is less common for large-scale pretraining but still used in some fine-tuning scenarios.

Warmup-Stable-Decay (WSD)

An emerging schedule used in recent models (e.g., MiniCPM). WSD has three phases:

  1. Warmup: Linear ramp-up (same as above).
  2. Stable: Hold at peak learning rate for the majority of training.
  3. Decay: Rapid cosine or exponential decay in the final phase (typically last 10-20%).

The insight behind WSD is that a constant learning rate during the middle phase enables the model to explore the loss landscape more freely, while the rapid final decay allows it to settle into a sharp minimum. This can also make it easier to resume training or extend the training budget without restarting the schedule from scratch.

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps,
                                     min_lr_ratio=0.1):
    """Cosine decay with linear warmup, used by LLaMA, GPT-3, etc."""
    def lr_lambda(step):
        # Linear warmup
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Cosine decay to min_lr_ratio of peak
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
    return LambdaLR(optimizer, lr_lambda)


def get_wsd_schedule(optimizer, warmup_steps, stable_steps, decay_steps,
                     min_lr_ratio=0.0):
    """Warmup-Stable-Decay schedule as used in MiniCPM."""
    total_steps = warmup_steps + stable_steps + decay_steps
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        elif step < warmup_steps + stable_steps:
            return 1.0
        else:
            decay_progress = (step - warmup_steps - stable_steps) / max(1, decay_steps)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
    return LambdaLR(optimizer, lr_lambda)

Training Recipes from Real Models

The following table summarizes hyperparameter choices from published LLM training runs:

| Hyperparameter | GPT-3 (175B) | LLaMA 2 (70B) | Chinchilla (70B) | Mistral (7B) |
|---|---|---|---|---|
| Optimizer | Adam | AdamW | AdamW | AdamW |
| Peak LR | 0.6e-4 | 1.5e-4 | 1.0e-4 | 3.0e-4 |
| \beta_1, \beta_2 | 0.9, 0.95 | 0.9, 0.95 | 0.9, 0.95 | 0.9, 0.95 |
| Weight Decay | 0.1 | 0.1 | 0.1 | 0.1 |
| Warmup | 375M tokens | 2000 steps | 1500 steps | 1000 steps |
| LR Schedule | Cosine | Cosine | Cosine | Cosine |
| Final LR | 10% of peak | 10% of peak | 10% of peak | 10% of peak |
| Batch Size (tokens) | 3.2M | 4M | 1.5M | 4M |
| Gradient Clipping | 1.0 | 1.0 | 1.0 | 1.0 |
| Total Tokens | 300B | 2T | 1.4T | ~8T (estimated) |

The convergence across different labs and model scales is striking --- the community has largely settled on AdamW with \beta_2 = 0.95, cosine decay to 10% of peak, and gradient clipping at 1.0.

Gradient Clipping

Gradient clipping is a simple but essential technique for training stability. Without it, a single bad batch can produce enormous gradients that corrupt learned parameters and cause loss spikes from which the model may never recover.

The standard approach is global norm clipping. First, compute the global gradient norm across all parameters:

\|g\|_{\text{global}} = \sqrt{\sum_{i} \|g_i\|^2}

If this exceeds a threshold c (typically c = 1.0), scale all gradients down proportionally:

g_i \leftarrow g_i \cdot \frac{c}{\max(c, \|g\|_{\text{global}})}

This preserves the direction of the gradient update while limiting its magnitude. The clipping threshold of 1.0 is nearly universal across LLM training runs, and gradient norm monitoring is one of the most important training diagnostics --- a sustained increase in gradient norms often precedes a loss spike or divergence.

def clip_grad_norm_(parameters, max_norm=1.0):
    """Global gradient norm clipping (simplified version of PyTorch's implementation)."""
    parameters = [p for p in parameters if p.grad is not None]
    # Compute global norm
    total_norm_sq = sum(p.grad.data.norm() ** 2 for p in parameters)
    total_norm = total_norm_sq.sqrt()

    # Scale gradients only when the norm exceeds the threshold (clip_coef <= 1)
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)
    for p in parameters:
        p.grad.data.mul_(clip_coef)

    return total_norm
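A quick check that clipping rescales rather than truncates: after clipping, the global norm is at most the threshold, while the ratios between individual gradient entries --- the update direction --- are unchanged. This sketch uses PyTorch's built-in implementation:

```python
import torch

torch.manual_seed(0)
params = [torch.nn.Parameter(torch.randn(10)) for _ in range(3)]
for p in params:
    p.grad = torch.randn(10) * 50.0  # deliberately huge gradients

ratio_before = (params[0].grad[0] / params[1].grad[0]).item()

# Returns the global norm measured BEFORE clipping
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)

global_norm_after = torch.sqrt(sum(p.grad.norm() ** 2 for p in params))
ratio_after = (params[0].grad[0] / params[1].grad[0]).item()

print(float(total_norm) > 1.0)            # True: the pre-clip norm was huge
print(float(global_norm_after) <= 1.001)  # True: rescaled to the threshold
print(abs(ratio_after / ratio_before - 1.0) < 1e-3)  # True: direction preserved
```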

Mixed Precision Training

The Numerical Precision Landscape

Modern GPUs offer several floating-point formats, each with different tradeoffs between range, precision, and throughput:

| Format | Total Bits | Sign | Exponent | Mantissa | Dynamic Range | Use Case |
|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | \pm 3.4 \times 10^{38} | Master weights, optimizer states |
| TF32 | 19 | 1 | 8 | 10 | \pm 3.4 \times 10^{38} | NVIDIA Tensor Core internal |
| BF16 | 16 | 1 | 8 | 7 | \pm 3.4 \times 10^{38} | Forward/backward pass |
| FP16 | 16 | 1 | 5 | 10 | \pm 65504 | Forward/backward (with loss scaling) |
| FP8 (E4M3) | 8 | 1 | 4 | 3 | \pm 448 | Emerging for training |
| FP8 (E5M2) | 8 | 1 | 5 | 2 | \pm 57344 | Emerging for gradients |

BF16 vs FP16: Why BF16 Won

The choice between BF16 and FP16 comes down to range vs. precision:

FP16 allocates 5 bits to the exponent and 10 to the mantissa. It has good precision (roughly 3 decimal digits) but a maximum representable value of only 65504. During transformer training, loss values and intermediate activations can easily exceed this range. When they do, the result is either infinity or NaN, and training diverges. The workaround is loss scaling: multiply the loss by a large constant before the backward pass (to push small gradients into representable range), then divide the gradients back down before the optimizer step. A GradScaler dynamically adjusts this scale factor, reducing it whenever overflow is detected.

BF16 allocates 8 bits to the exponent (same as FP32) and only 7 to the mantissa. This gives it the same dynamic range as FP32 (\pm 3.4 \times 10^{38}) at the cost of lower precision (roughly 2 decimal digits). The key insight is that for neural network training, range matters more than precision. Gradients vary over many orders of magnitude, and BF16 handles this natively without any loss scaling. The reduced precision is acceptable because the stochasticity of SGD already introduces noise far larger than BF16 rounding errors.

In practice, BF16 has become the standard for LLM training because it eliminates the fragile loss-scaling machinery, simplifies the training code, and produces equivalent final model quality.
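The range/precision tradeoff is easy to demonstrate on CPU tensors: a value that fits comfortably in BF16 overflows FP16, while BF16's 7-bit mantissa is visible in how coarsely it rounds near 1.0:

```python
import torch

x = torch.tensor(70000.0)

# FP16 tops out at 65504, so 70000 overflows to inf
fp16 = x.to(torch.float16)

# BF16 shares FP32's 8-bit exponent, so the same value stays finite
bf16 = x.to(torch.bfloat16)

# The cost is precision: with 7 mantissa bits, the spacing between
# adjacent BF16 values near 1.0 is 2^-7 ~= 0.0078
y = torch.tensor(1.0009765).to(torch.bfloat16)

print(torch.isinf(fp16).item())     # True
print(torch.isfinite(bf16).item())  # True
print(y.item())                     # 1.0: the small offset rounds away
```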

The Mixed Precision Strategy

Even with BF16, certain operations must remain in FP32 to maintain numerical stability:

  1. Master weights are stored in FP32. The optimizer updates are computed in FP32 and applied to these master weights.
  2. Forward pass runs in BF16. A BF16 copy of the weights is used for the forward computation. On modern GPUs, this runs at 2x the throughput of FP32 on Tensor Cores.
  3. Backward pass runs in BF16. Gradients are computed in reduced precision.
  4. Gradient accumulation is done in FP32. When accumulating gradients across micro-batches, the accumulated buffer must be FP32 to avoid precision loss from repeated additions of small values.
  5. Optimizer step operates in FP32. The moment estimates (mtm_t, vtv_t) and the parameter update are computed in full precision.

Certain numerical operations should always use FP32 regardless of the global precision setting: softmax (exponentiation is sensitive to rounding), layer normalization (variance computation), and loss computation (log-probabilities can be very small).
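The accumulation pitfall (point 4 above) is easy to demonstrate: repeatedly adding a small increment to a BF16 accumulator can lose every single addition to rounding, while an FP32 accumulator keeps them. A minimal sketch:

```python
import torch

increment, steps = 1e-4, 1000

# BF16 accumulator: 1.0 + 1e-4 rounds straight back to 1.0 every time,
# because the spacing between adjacent BF16 values near 1.0 is 2^-7 ~= 0.0078
acc_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
step_bf16 = torch.tensor(increment, dtype=torch.bfloat16)
for _ in range(steps):
    acc_bf16 = acc_bf16 + step_bf16

# FP32 accumulator keeps every contribution
acc_fp32 = torch.tensor(1.0, dtype=torch.float32)
for _ in range(steps):
    acc_fp32 = acc_fp32 + increment

print(acc_bf16.item())  # 1.0: all 1000 additions were lost to rounding
print(acc_fp32.item())  # ~1.1
```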

import torch
from torch.cuda.amp import autocast

def mixed_precision_training_loop(model, dataloader, optimizer, num_epochs,
                                   grad_accum_steps=4, max_grad_norm=1.0):
    """Full mixed precision training loop with BF16.

    BF16 does not require GradScaler (unlike FP16), simplifying the code.
    """
    model.train()

    for epoch in range(num_epochs):
        for step, batch in enumerate(dataloader):
            is_accumulation_step = (step + 1) % grad_accum_steps != 0

            # Forward pass in BF16
            with autocast(dtype=torch.bfloat16):
                logits = model(batch["input_ids"])
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    batch["labels"].view(-1),
                    ignore_index=-100,
                )
                loss = loss / grad_accum_steps

            # Backward pass (gradients computed in BF16, accumulated in FP32)
            loss.backward()

            if not is_accumulation_step:
                # Clip gradients (computed on FP32 gradient buffers)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=max_grad_norm
                )
                # Optimizer step updates FP32 master weights
                optimizer.step()
                optimizer.zero_grad()

                if step % 100 == 0:
                    print(f"Step {step}, Loss: {loss.item():.4f}, "
                          f"Grad Norm: {grad_norm:.4f}")

Memory Savings from Mixed Precision

For a model with N parameters, the memory breakdown per parameter is:

| Component | FP32 Only | Mixed Precision (BF16 + FP32) |
|---|---|---|
| Model weights | 4 bytes (FP32) | 2 bytes (BF16) + 4 bytes (FP32 master) |
| Gradients | 4 bytes (FP32) | 2 bytes (BF16) |
| Optimizer states (AdamW) | 8 bytes (2x FP32) | 8 bytes (2x FP32) |
| Total per parameter | 16 bytes | 16 bytes |
| Activations (saved for backward) | FP32 | BF16 (2x reduction) |

The parameter-level savings are modest, but the activation memory savings are substantial. Activations dominate memory for long sequences, and storing them in BF16 cuts that memory in half, enabling longer context lengths or larger batch sizes.

Gradient Checkpointing

The Activation Memory Problem

During the backward pass, we need the activations from the forward pass to compute gradients. Naively, this means storing activations for every layer. For a model with L layers, batch size B, sequence length T, and hidden dimension d:

\text{Activation Memory} \approx L \times B \times T \times d \times \text{bytes\_per\_element}

For LLaMA-70B (L = 80, d = 8192) with B = 1, T = 4096, in BF16: this is roughly 80 \times 1 \times 4096 \times 8192 \times 2 \approx 5.4 GB just for the hidden state activations. Including attention scores, FFN intermediates, and normalization buffers, the actual figure is several times higher and can easily exceed 100 GB.
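The estimate above can be reproduced directly. This helper counts only the per-layer hidden-state tensors, so it is a lower bound on true activation memory:

```python
def hidden_state_memory_gb(n_layers, batch, seq_len, d_model, bytes_per_el=2):
    """Lower-bound activation memory: one hidden-state tensor per layer, in GB."""
    return n_layers * batch * seq_len * d_model * bytes_per_el / 1e9

# LLaMA-70B shapes: 80 layers, d_model 8192, one sequence of 4096 tokens, BF16
gb = hidden_state_memory_gb(n_layers=80, batch=1, seq_len=4096, d_model=8192)
print(f"{gb:.1f} GB")  # 5.4 GB for the hidden states alone
```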

How Checkpointing Works

Gradient checkpointing (Chen et al., 2016) trades compute for memory. Instead of storing all activations, we designate certain layers as "checkpoints" and only store their outputs. During the backward pass, when we need activations from non-checkpointed layers, we recompute them by running a partial forward pass from the nearest checkpoint.

The most common strategy is to checkpoint every transformer block boundary:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TransformerBlock(nn.Module):
    """A single transformer block (simplified)."""
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff_norm = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        h = self.attn_norm(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.ff(self.ff_norm(x))
        return x


class CheckpointedTransformerModel(nn.Module):
    """Transformer with gradient checkpointing to reduce activation memory."""
    def __init__(self, n_layers, d_model, n_heads, d_ff, use_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        for layer in self.layers:
            if self.use_checkpointing and self.training:
                # Activations for this layer are NOT stored.
                # They will be recomputed during backward pass.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

The Memory-Compute Tradeoff

The tradeoff is clean and predictable. Checkpointing every layer means storing only each block's input and recomputing the block's internal activations during the backward pass, reducing per-layer activation storage from the full set of intermediates to O(1) per block, at the cost of roughly doubling forward compute. The optimal strategy of Chen et al. (2016) checkpoints every \sqrt{L} layers, achieving O(\sqrt{L}) memory with approximately 50% forward-compute overhead.

| Strategy | Activation Memory | Forward Compute Overhead | When to Use |
|---|---|---|---|
| No checkpointing | O(L) | 1.0x | Small models, ample GPU memory |
| Every \sqrt{L} layers | O(\sqrt{L}) | ~1.5x | Balanced approach |
| Every layer | O(1) | ~2.0x | Maximum memory savings |
| Selective (attention only) | ~O(L/2) | ~1.3x | Attention is the bottleneck |

In practice, checkpointing every layer is most common for large model training because the compute overhead (roughly 30-40% in practice, less than the theoretical 2x due to memory bandwidth effects) is acceptable given the memory savings.

Distributed Training

A 70B parameter model in FP32 requires 280 GB just for the weights, far exceeding any single GPU's memory. Even with mixed precision and checkpointing, training at scale requires distributing the computation across many GPUs.

Data Parallelism (DDP)

The simplest distributed strategy is to replicate the model on every GPU and split the data. Each GPU processes a different mini-batch, computes gradients independently, then synchronizes gradients via an all-reduce operation before the optimizer step.

DDP (Distributed Data Parallel) in PyTorch overlaps gradient communication with backward computation: as soon as a layer's gradients are computed, the all-reduce for that layer begins while the backward pass continues through earlier layers.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    """Initialize DDP process group."""
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_with_ddp(rank, world_size, model, dataloader, optimizer):
    setup_ddp(rank, world_size)
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])

    for batch in dataloader:
        batch = {k: v.to(rank) for k, v in batch.items()}
        loss = model(batch["input_ids"], labels=batch["labels"]).loss
        loss.backward()   # Gradients all-reduced automatically
        optimizer.step()
        optimizer.zero_grad()

DDP's limitation: every GPU must hold a full copy of the model. For a 70B model with AdamW, this means ~1.1 TB per GPU (weights + gradients + optimizer states) --- clearly infeasible.

Fully Sharded Data Parallelism (FSDP)

FSDP (Zhao et al., 2023) extends data parallelism by sharding model parameters, gradients, and optimizer states across GPUs. This is conceptually similar to DeepSpeed ZeRO Stage 3.

The key operations in FSDP are:

  1. All-gather before each layer's forward pass: collect the full parameter tensor from all GPUs.
  2. Forward computation: Run the layer with the full parameters.
  3. Discard the gathered parameters after use (only keep the local shard).
  4. All-gather again during the backward pass, compute gradients.
  5. Reduce-scatter gradients: each GPU receives the gradient shard corresponding to its parameter shard.

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

# Define wrapping policy: shard at the transformer block level
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Mixed precision configuration
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,     # Parameters gathered in BF16
    reduce_dtype=torch.float32,     # Gradient reduction in FP32
    buffer_dtype=torch.bfloat16,    # Buffers in BF16
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mixed_precision_policy,
    auto_wrap_policy=auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)

Per-GPU Memory Comparison

The memory reduction from FSDP is dramatic. For a 70B parameter model with AdamW in mixed precision across 8 GPUs:

| Component | DDP (per GPU) | FSDP Full Shard (per GPU) |
|---|---|---|
| Parameters (BF16) | 140 GB | 17.5 GB |
| Gradients (BF16) | 140 GB | 17.5 GB |
| Optimizer states (FP32) | 560 GB | 70 GB |
| FP32 master weights | 280 GB | 35 GB |
| Total | 1,120 GB | 140 GB |

This makes it feasible to train a 70B model on 8x 80GB A100s with FSDP, which would be impossible with DDP.

Tensor Parallelism (TP)

Tensor parallelism (Shoeybi et al., 2019) splits individual matrix operations across GPUs. For a linear layer Y = XW where W \in \mathbb{R}^{d \times d}, we can partition W column-wise across N GPUs:

Y = X[W_1 \mid W_2 \mid \ldots \mid W_N] = [XW_1 \mid XW_2 \mid \ldots \mid XW_N]

Each GPU computes Y_i = XW_i and the results are concatenated via an all-gather. For the MLP in a transformer block, this is applied to both linear layers with complementary splits (column-parallel for the first, row-parallel for the second) to minimize communication.
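The column-split identity can be verified in a few lines. This is a single-process simulation of the sharded matmul; real tensor parallelism would place each W_i on a different GPU and all-gather the Y_i:

```python
import torch

torch.manual_seed(0)
d, n_shards = 16, 4
X = torch.randn(2, d)
W = torch.randn(d, d)

# Column-parallel: each "GPU" holds a slice of W's columns
shards = W.chunk(n_shards, dim=1)      # N shards of shape (d, d/N)
Y_parts = [X @ W_i for W_i in shards]  # local matmuls, no communication needed
Y = torch.cat(Y_parts, dim=1)          # the all-gather/concatenation step

# Sharded result matches the unsharded matmul
assert torch.allclose(Y, X @ W, atol=1e-5)
```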

Tensor parallelism is typically used with degree 2, 4, or 8 within a single node where GPUs are connected by NVLink (900 GB/s on H100 SXM), because the all-reduce communication at every layer is latency-sensitive.

Pipeline Parallelism (PP)

Pipeline parallelism assigns different groups of layers to different GPUs. A model with 80 layers across 4 GPUs might assign layers 1-20 to GPU 0, 21-40 to GPU 1, and so on.

The naive implementation creates "pipeline bubbles" where GPUs sit idle waiting for activations from earlier stages. Micro-batching (GPipe) and 1F1B scheduling (PipeDream) reduce bubble overhead by breaking the batch into smaller micro-batches and overlapping computation across stages.

The bubble fraction for a P-stage pipeline with M micro-batches is approximately:

\text{Bubble Fraction} = \frac{P - 1}{M + P - 1}

With M \gg P, the bubble overhead becomes negligible.
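Plugging representative numbers into the bubble formula shows why large micro-batch counts matter (a back-of-the-envelope helper):

```python
def bubble_fraction(p_stages, m_microbatches):
    """Idle fraction of a GPipe-style pipeline schedule."""
    return (p_stages - 1) / (m_microbatches + p_stages - 1)

# A 4-stage pipeline with increasing micro-batch counts
print(f"{bubble_fraction(4, 4):.1%}")    # 42.9%: nearly half the time wasted
print(f"{bubble_fraction(4, 32):.1%}")   # 8.6%
print(f"{bubble_fraction(4, 128):.1%}")  # 2.3%
```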

3D Parallelism

For the largest training runs (thousands of GPUs), all three strategies are combined:

\text{Total GPUs} = \text{DP} \times \text{TP} \times \text{PP}

Typical configurations exploit the hardware topology:

  • TP within a node (4-8 GPUs connected by NVLink).
  • PP across nodes within a rack (fast interconnect).
  • DP/FSDP across racks.

Example: training a 175B model on 1024 H100 GPUs might use TP=8 (within each node), PP=4 (across 4 nodes), and DP=32 (32 groups of 4 nodes).

| Parallelism | What Is Split | Communication | Typical Scale |
|---|---|---|---|
| Data (DDP/FSDP) | Batches (and optionally model state) | All-reduce / reduce-scatter | 8-1000s of GPUs |
| Tensor (TP) | Individual layers/matrices | All-reduce per layer | 2-8 GPUs (within node) |
| Pipeline (PP) | Groups of layers | Point-to-point activations | 2-16 stages |

Putting It All Together: A Training Recipe

Combining everything discussed, here is a summary of the standard configuration for modern LLM pretraining:

| Component | Standard Choice | Rationale |
|---|---|---|
| Optimizer | AdamW (\beta_1 = 0.9, \beta_2 = 0.95, \epsilon = 10^{-8}) | Decoupled weight decay; \beta_2 = 0.95 improves late-training stability |
| Weight Decay | 0.1 | Applied to all params except biases and LayerNorm |
| Peak Learning Rate | Scales with model size (e.g., 3e-4 for 7B, 1.5e-4 for 70B) | Smaller models tolerate higher LR |
| Warmup | 1-2% of total steps (1000-2000 steps) | Stabilizes Adam's variance estimate |
| LR Schedule | Cosine decay to 10% of peak | Smooth decay, well-studied empirically |
| Precision | BF16 mixed precision | Same range as FP32, no loss scaling needed |
| Gradient Clipping | Global norm = 1.0 | Prevents loss spikes from bad batches |
| Gradient Checkpointing | Every transformer block | Essential for large models |
| Batch Size | Ramp from small to large (4M tokens typical) | Large batch for throughput, small early for stability |
| Distributed Strategy | FSDP + TP (within node) | Balances memory and communication |

In the next post, we will explore Part 6: Inference Optimization --- the challenges of deploying trained models in production, including KV-cache mechanics, quantization from INT8 to INT4, speculative decoding for faster generation, and continuous batching with PagedAttention.

References

  • Kingma, D. P. & Ba, J. (2015). Adam: A Method for Stochastic Optimization. ICLR 2015. arXiv:1412.6980.
  • Loshchilov, I. & Hutter, F. (2019). Decoupled Weight Decay Regularization. ICLR 2019. arXiv:1711.05101.
  • Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training Deep Nets with Sublinear Memory Cost. arXiv:1604.06174.
  • Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J., & Catanzaro, B. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv:1909.08053.
  • Zhao, Y., Gu, A., Varma, R., Luo, L., Huang, C., Xu, M., ... & Chintala, S. (2023). PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. VLDB 2023. arXiv:2304.11277.
  • Touvron, H., Lavril, T., Izacard, G., et al. (2023). LLaMA: Open and Efficient Foundation Language Models. arXiv:2302.13971.
  • Touvron, H., Martin, L., Stone, K., et al. (2023). LLaMA 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288.
  • Brown, T. B., Mann, B., Ryder, N., et al. (2020). Language Models are Few-Shot Learners. NeurIPS 2020. arXiv:2005.14165.
  • Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). Training Compute-Optimal Large Language Models. arXiv:2203.15556.
  • Chen, X., Liang, C., Huang, D., et al. (2023). Symbolic Discovery of Optimization Algorithms (Lion). arXiv:2302.06675.
  • Micikevicius, P., Narang, S., Alben, J., et al. (2018). Mixed Precision Training. ICLR 2018. arXiv:1710.03740.
  • Hu, S., Tu, Y., Han, X., et al. (2024). MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies. arXiv:2404.06395.

Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
