Tags: transformers · training · optimization · adamw · mixed-precision · distributed-training

Transformer Deep Dive: Part 5 - Training Improvements

Modern training techniques for LLMs - AdamW optimizer, learning rate schedules, mixed precision training (FP16/BF16), gradient checkpointing, and distributed training strategies.


Suchinthaka W.

January 19, 2025 · 7 min read

Training large language models requires carefully orchestrated techniques that address three fundamental challenges: optimization stability across billions of parameters, memory constraints on limited GPU resources, and computational efficiency across distributed systems.

The Training Loop

The training of a transformer involves iteratively updating parameters $\theta$ to minimize a loss function:

$$\theta_{t+1} = \theta_t - \eta \cdot \text{Update}(\nabla_\theta \mathcal{L}, \theta_t, t)$$

where the specific form of Update(·) defines the optimizer, $\eta$ is the learning rate (often scheduled), and computation may be distributed across devices with reduced precision.
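
To make the moving parts concrete, here is a minimal sketch of a single training step in PyTorch; `model`, `dataloader`, `optimizer`, and `scheduler` are placeholders for the choices discussed in the rest of this post.

```python
import torch

# Minimal sketch of one training iteration.
for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    loss = model(batch)               # forward pass computes the loss L
    loss.backward()                   # backward pass computes ∇_θ L
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
    optimizer.step()                  # Update(·), e.g. AdamW
    scheduler.step()                  # advance the learning-rate schedule η(t)
```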

Optimizers

Adam (Adaptive Moment Estimation)

Adam combines momentum with adaptive learning rates per parameter:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

$$\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Typical values: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$

AdamW: Decoupled Weight Decay

A critical discovery: L2 regularization and weight decay are not equivalent in Adam!

The Problem: In standard Adam with L2 regularization, the penalty gradient $\lambda\theta_t$ ends up divided by the same adaptive denominator as the data gradient:

$$\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t + \lambda\theta_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Parameters with a large gradient history (large $\hat{v}_t$) therefore receive a smaller effective weight decay, the opposite of what we want.

AdamW Solution: Apply weight decay directly to parameters, outside the adaptive update:

$$\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\theta_t \right)$$

# Simplified AdamW step (sketch): m, v, and the step count t live in a
# per-parameter state dict; beta1, beta2, lr, eps, weight_decay are scalars.
for param in model.parameters():
    if param.grad is None:
        continue

    s = state[param]
    s["t"] += 1
    t = s["t"]

    # Standard Adam moment estimates with bias correction
    s["m"] = beta1 * s["m"] + (1 - beta1) * param.grad
    s["v"] = beta2 * s["v"] + (1 - beta2) * param.grad ** 2
    m_hat = s["m"] / (1 - beta1 ** t)
    v_hat = s["v"] / (1 - beta2 ** t)

    # AdamW: weight decay applied directly to the parameter, outside the adaptive step
    param.data -= lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * param.data)

Optimizer Comparison

| Optimizer | Memory (per param) | Key Feature |
|-----------|--------------------|-------------|
| SGD | 0 bytes | Simple, needs tuning |
| SGD + Momentum | 4 bytes | More stable |
| Adam | 8 bytes | Adaptive LR |
| AdamW | 8 bytes | Proper weight decay |
| Adafactor | 4 bytes* | Memory efficient |
| Lion | 4 bytes | Simpler, competitive |

*Adafactor uses factored second moments
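
As a rough back-of-the-envelope check of the table, assuming FP32 optimizer state (4 bytes per moment per parameter):

```python
def optimizer_state_bytes(n_params: int, moments: int, bytes_per_moment: int = 4) -> int:
    """Rough optimizer-state footprint: `moments` FP32 tensors, one value per parameter."""
    return n_params * moments * bytes_per_moment

n = 70_000_000_000  # e.g. a 70B-parameter model
print(optimizer_state_bytes(n, moments=1) / 1e9)  # SGD + momentum: ~280 GB
print(optimizer_state_bytes(n, moments=2) / 1e9)  # Adam/AdamW (m and v): ~560 GB
```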

Learning Rate Schedules

Warmup

Warmup is critical for transformer training. Start with a small learning rate and gradually increase:

$$\text{lr}(t) = \text{lr}_{max} \cdot \frac{t}{\text{warmup\_steps}}, \quad t < \text{warmup\_steps}$$

Why Warmup?

  • Adam's second-moment estimate $v_t$ is noisy early in training, making the adaptive step sizes unreliable
  • Large initial gradients can destabilize training
  • Especially important with Post-LN transformers (Pre-LN architectures reduce, but rarely eliminate, the need for it)

Cosine Decay

After warmup, decay the learning rate following a cosine curve:

$$\text{lr}(t) = \text{lr}_{min} + \frac{1}{2}(\text{lr}_{max} - \text{lr}_{min})\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$

Linear Decay

Simple linear decrease:

$$\text{lr}(t) = \text{lr}_{max} \cdot \left(1 - \frac{t}{T}\right)$$
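
Combined with the warmup phase above, this fits in a single LambdaLR factor; here is a minimal sketch (a cosine variant follows after the table below):

```python
from torch.optim.lr_scheduler import LambdaLR

def get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # linear warmup to lr_max
        # linear decay from lr_max down to 0 over the remaining steps
        return max(0.0, (total_steps - step) / (total_steps - warmup_steps))
    return LambdaLR(optimizer, lr_lambda)
```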

Common Configurations

| Model | Warmup | Decay | Final LR |
|-------|--------|-------|----------|
| GPT-3 | 375M tokens | Cosine | 10% of max |
| LLaMA | 2000 steps | Cosine | 10% of max |
| Chinchilla | 1500 steps | Cosine | 10% of max |

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # linear warmup
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.1 + 0.45 * (1 + math.cos(math.pi * progress))  # cosine decay to 10% of max
    return LambdaLR(optimizer, lr_lambda)
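
Usage follows the standard LambdaLR pattern: step the scheduler once per optimizer step. The learning rate and step counts below are illustrative placeholders.

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=2000, total_steps=100_000)

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)   # assuming the model returns the loss
    loss.backward()
    optimizer.step()
    scheduler.step()      # advance the LR schedule every optimizer step
```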

Mixed Precision Training

The Precision Hierarchy

| Format | Bits | Range | Precision | Use Case |
|--------|------|-------|-----------|----------|
| FP32 | 32 | ±3.4e38 | High | Master weights |
| TF32 | 19 | ±3.4e38 | Medium | Tensor cores |
| BF16 | 16 | ±3.4e38 | Low | Training |
| FP16 | 16 | ±65504 | Medium | Training |
| FP8 | 8 | ±448 | Low | Inference |

BF16 vs FP16

FP16 (IEEE Half Precision):

  • 1 sign, 5 exponent, 10 mantissa
  • Higher precision, limited range
  • Needs loss scaling to keep small gradients from underflowing to zero

BF16 (Brain Float):

  • 1 sign, 8 exponent, 7 mantissa
  • Lower precision, same range as FP32
  • No loss scaling needed

FP32:  [1][8 exponent bits][23 mantissa bits]
FP16:  [1][5 exponent bits][10 mantissa bits]
BF16:  [1][8 exponent bits][7 mantissa bits]
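
The range/precision trade-off can be inspected directly with `torch.finfo`; a quick sketch:

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # max ≈ dynamic range; eps ≈ gap between 1.0 and the next representable value
    print(f"{dtype}: max={info.max:.3e}, eps={info.eps:.3e}")

# float16:  max≈6.550e+04, eps≈9.77e-04  -> small range, finer precision
# bfloat16: max≈3.389e+38, eps≈7.81e-03  -> FP32-like range, coarser precision
```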

Mixed Precision Strategy

  1. Master weights in FP32
  2. Forward pass in FP16/BF16
  3. Backward pass in FP16/BF16
  4. Gradient accumulation in FP32
  5. Optimizer update in FP32

import torch
from torch.cuda.amp import autocast, GradScaler

# Loss scaling is only needed for FP16; with BF16 the scaler can stay disabled,
# in which case scale()/step()/update() become simple pass-throughs.
scaler = GradScaler(enabled=False)  # set enabled=True when using dtype=torch.float16
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in reduced precision
    with autocast(dtype=torch.bfloat16):
        loss = model(batch)

    # Backward + optimizer step (loss scaling applied only when the scaler is enabled)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Memory Savings

| Precision | Memory per Param | Relative |
|-----------|------------------|----------|
| FP32 weights + FP32 optimizer | 16 bytes | 1× |
| FP16 weights + FP32 optimizer | 12 bytes | 0.75× |
| BF16 weights + FP32 optimizer | 12 bytes | 0.75× |

Gradient Checkpointing

The Memory Problem

During backpropagation, we need activations from the forward pass. For a model with L layers:

$$\text{Activation Memory} = O(L \times B \times T \times d)$$

For a 70B model with 80 layers, 2K context, this can exceed 100GB!
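
A rough sanity check, using the commonly cited per-layer approximation from the activation-recomputation literature (about $s \cdot b \cdot h \cdot (34 + 5as/h)$ bytes per layer in 16-bit precision); the shapes below are assumptions for a 70B-scale model:

```python
# Rough activation-memory estimate per transformer layer (16-bit activations),
# using the approximation ~ s*b*h*(34 + 5*a*s/h) bytes. Shapes are illustrative.
s, b, h, a, L = 2048, 1, 8192, 64, 80   # seq len, batch, hidden dim, heads, layers
per_layer = s * b * h * (34 + 5 * a * s / h)   # ≈ 1.9 GB
total = per_layer * L                           # ≈ 153 GB across 80 layers
print(per_layer / 1e9, total / 1e9)
```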

Checkpointing Strategy

Trade compute for memory: Don't store all activations. Instead:

  1. Store activations at "checkpoint" layers only
  2. During backward pass, recompute activations between checkpoints

import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class CheckpointedTransformer(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)  # register layers as submodules

    def forward(self, x):
        for layer in self.layers:
            # Activations inside `layer` are discarded and recomputed during backward
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x

Memory-Compute Tradeoff

| Strategy | Memory | Compute |
|----------|--------|---------|
| No checkpointing | O(L) | 1× |
| Every layer | O(1) | ~2× |
| Every √L layers | O(√L) | ~1.5× |
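
The √L variant can be approximated with `torch.utils.checkpoint.checkpoint_sequential`, which splits a layer stack into segments and only stores activations at segment boundaries; a minimal sketch with stand-in layers:

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(64)])  # stand-in for transformer blocks
segments = int(math.sqrt(len(layers)))  # checkpoint roughly every √L layers

x = torch.randn(8, 1024, requires_grad=True)
# Stores activations only at segment boundaries; recomputes inside each segment on backward.
y = checkpoint_sequential(layers, segments, x, use_reentrant=False)
```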

Distributed Training

Data Parallelism (DP)

Simplest approach: replicate model on each GPU, split data.

GPU 0: Full model, Batch 0
GPU 1: Full model, Batch 1
GPU 2: Full model, Batch 2
GPU 3: Full model, Batch 3
                ↓
        All-Reduce gradients
                ↓
        Synchronized update
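
In PyTorch this is typically DistributedDataParallel; a minimal sketch, assuming one process per GPU launched via torchrun and an already-constructed `model`:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")               # one process per GPU (e.g. via torchrun)
local_rank = int(os.environ["LOCAL_RANK"])    # set by torchrun
torch.cuda.set_device(local_rank)

model = model.to(local_rank)                  # every rank holds a full replica
model = DDP(model, device_ids=[local_rank])   # gradients are all-reduced during backward
```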

Fully Sharded Data Parallelism (FSDP)

Shard model parameters, gradients, and optimizer states across GPUs:

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads, optimizer state
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # compute in BF16
        reduce_dtype=torch.float32,   # gradient reduction in FP32
    ),
)

Memory per GPU (70B Model)

| Strategy | Params | Grads | Optimizer | Total |
|----------|--------|-------|-----------|-------|
| No parallelism | 280GB | 280GB | 560GB | 1.1TB |
| DP (8 GPUs) | 280GB | 280GB | 560GB | 1.1TB |
| FSDP (8 GPUs) | 35GB | 35GB | 70GB | 140GB |

Tensor Parallelism

Split individual layers across GPUs:

Linear Layer: Y = XW

GPU 0: Y_0 = X @ W_0    (first half of columns)
GPU 1: Y_1 = X @ W_1    (second half of columns)
            ↓
        All-Gather
            ↓
        Y = [Y_0, Y_1]
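
The column split can be illustrated on a single device by chunking the weight matrix; in a real implementation each shard lives on its own GPU and the concatenation is an all-gather. Shapes below are made up for the sketch:

```python
import torch

X = torch.randn(4, 512)          # activations, replicated on both "GPUs"
W = torch.randn(512, 1024)       # full weight matrix
W0, W1 = W.chunk(2, dim=1)       # column shards held by GPU 0 and GPU 1

Y0 = X @ W0                      # computed on GPU 0
Y1 = X @ W1                      # computed on GPU 1
Y = torch.cat([Y0, Y1], dim=1)   # all-gather in the distributed setting

assert torch.allclose(Y, X @ W, atol=1e-5)
```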

Pipeline Parallelism

Split layers across GPUs, process micro-batches in pipeline:

Time  →  T0    T1    T2    T3    T4    T5
GPU 0:   F0    F1    F2    B0    B1    B2
GPU 1:         F0    F1    F2    B0    B1
GPU 2:               F0    F1    F2    B0
GPU 3:                     F0    F1    F2

F = Forward, B = Backward
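
A naive two-stage split (no micro-batch interleaving, so only one GPU is busy at a time) looks like the sketch below; frameworks such as GPipe add the interleaved schedule shown above. The `layers` list and device names are assumptions:

```python
import torch.nn as nn

# Naive pipeline: first half of the layers on GPU 0, second half on GPU 1.
stage0 = nn.Sequential(*layers[: len(layers) // 2]).to("cuda:0")
stage1 = nn.Sequential(*layers[len(layers) // 2 :]).to("cuda:1")

def pipeline_forward(x):
    x = stage0(x.to("cuda:0"))
    x = stage1(x.to("cuda:1"))   # activations cross the GPU boundary here
    return x
```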

3D Parallelism

Combine all strategies for maximum scale:

$$\text{World Size} = \text{DP} \times \text{TP} \times \text{PP}$$

Example: 1024 GPUs = 128 DP × 4 TP × 2 PP

Training Recipe Summary

| Component | Recommendation |
|-----------|----------------|
| Optimizer | AdamW ($\beta_1=0.9$, $\beta_2=0.95$) |
| Weight Decay | 0.1 |
| Warmup | 1-2% of training |
| LR Schedule | Cosine decay to 10% |
| Precision | BF16 mixed precision |
| Gradient Clipping | 1.0 |
| Batch Size | As large as memory allows |
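
Put together, a minimal sketch of this recipe (the base LR, step counts, and the `get_cosine_schedule_with_warmup` helper from earlier are placeholders, not a tuned configuration):

```python
import torch

optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=2000, total_steps=100_000)

for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast("cuda", dtype=torch.bfloat16):   # BF16 mixed precision
        loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
    optimizer.step()
    scheduler.step()
```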


In the next post, we'll explore Part 6: Inference Optimization - KV-cache, quantization, speculative decoding, and continuous batching for production deployment.
