Training a large language model may cost millions of dollars, but inference --- the process of generating text from a trained model --- accounts for the vast majority of total compute expenditure over a model's lifetime. A model trained once will serve billions of requests. At this scale, a 2x improvement in inference throughput directly halves serving costs, making inference optimization one of the highest-leverage problems in production ML.

In this post, we dissect the key techniques that make LLM serving practical: why autoregressive decoding is memory-bound rather than compute-bound, how the KV-cache eliminates redundant computation (and why it creates its own memory challenges), quantization methods that shrink models by 4-8x with minimal quality loss, speculative decoding that generates multiple tokens per forward pass, and continuous batching with PagedAttention for maximizing GPU utilization.

The Inference Challenge

LLM inference has two distinct phases with very different computational characteristics:

  1. Prefill (prompt processing): The entire input prompt is processed in a single forward pass. This phase is compute-bound --- the GPU's arithmetic units are the bottleneck, similar to training. Matrix multiplications operate on full sequence-length tensors, achieving high arithmetic intensity.

  2. Decode (token generation): Tokens are generated one at a time, autoregressively. Each step requires loading the entire model from GPU memory to compute a single token's logits. This phase is memory-bandwidth-bound --- the GPU spends most of its time waiting for data to arrive from HBM, not performing arithmetic.

The key metrics for production serving reflect these two phases:

| Metric | Description | Typical Target |
|---|---|---|
| Time to First Token (TTFT) | Latency of the prefill phase | < 500ms |
| Time per Output Token (TPOT) | Latency of each decode step | < 50ms |
| Throughput | Total tokens generated per second across all requests | Maximize |
| Memory Efficiency | Fraction of GPU memory used productively | Maximize |

Why Decoding is Memory-Bound

To understand the memory bottleneck, consider the arithmetic intensity of a single decode step. For a model with $N$ parameters, generating one token requires roughly $2N$ FLOPs (one multiply-add per parameter). With a 70B model, that is $1.4 \times 10^{11}$ FLOPs. Meanwhile, loading all 70B parameters in BF16 requires reading $1.4 \times 10^{11}$ bytes from HBM.

The arithmetic intensity is:

$$\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes Loaded}} = \frac{2N}{2N} = 1 \text{ FLOP/byte}$$

Modern GPUs have compute-to-bandwidth ratios far exceeding 1:

| GPU | HBM Bandwidth (TB/s) | BF16 TFLOPs | Compute:Bandwidth Ratio |
|---|---|---|---|
| A100 80GB SXM | 2.0 | 312 | 156:1 |
| H100 80GB SXM | 3.35 | 990 | 296:1 |
| H200 141GB | 4.8 | 990 | 206:1 |

With a ratio of ~200:1, the GPU can perform 200 FLOPs for every byte loaded, but single-token decoding only needs 1 FLOP per byte. This means the compute units are idle more than 99% of the time during decoding, waiting for parameters to stream from memory. This is the fundamental reason why batching (processing multiple requests simultaneously) and reducing memory transfers (via quantization, caching) are so impactful.
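
These ratios translate directly into a throughput ceiling. A back-of-envelope roofline calculation (a sketch; the model and GPU figures are the ones used above) shows how hard the memory wall bites:

```python
def decode_roofline(n_params: float, bytes_per_param: float,
                    hbm_bw: float, peak_flops: float) -> dict:
    """Estimate the single-token decode ceiling for one request.

    Assumes every weight is read once per token (ignores KV-cache and
    activation traffic, which only make things worse).
    """
    t_mem = n_params * bytes_per_param / hbm_bw   # time to stream the weights
    t_compute = 2 * n_params / peak_flops         # one multiply-add per weight
    bound = max(t_mem, t_compute)
    return {
        "tokens_per_sec": 1.0 / bound,
        "memory_bound": t_mem > t_compute,
        "compute_utilization": t_compute / bound,
    }

# 70B parameters in BF16 on an H100 (3.35 TB/s, 990 BF16 TFLOPs):
est = decode_roofline(70e9, 2, 3.35e12, 990e12)
# memory-bound, with a ceiling of roughly 24 tokens/s per request
```

Batching raises arithmetic intensity: with batch size B, the same weight read serves B tokens, which is why throughput scales almost linearly with batch size until compute or memory capacity saturates.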

KV-Cache

The Redundant Computation Problem

In autoregressive generation, we produce tokens one at a time. At each step $t$, the self-attention mechanism computes:

$$\text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t}$$

The query $Q_t$ only contains the new token, but the keys $K_{1:t}$ and values $V_{1:t}$ include all previous tokens. Without caching, we would recompute the key and value projections for every previous token at every step:

Step 1: Compute K₁, V₁ for token "The"
Step 2: Compute K₁, V₁ for "The", K₂, V₂ for "cat"        → K₁,V₁ recomputed!
Step 3: Compute K₁,V₁, K₂,V₂, K₃,V₃ for "The cat sat"    → K₁,V₁,K₂,V₂ recomputed!
Step 4: Compute K₁..₃, V₁..₃, K₄, V₄ for "The cat sat on" → K₁..₃,V₁..₃ recomputed!

The total computation for generating $n$ tokens scales as $O(n^2 \cdot d)$, since step $t$ re-projects all $t$ tokens.

How KV-Cache Works

The solution is straightforward: cache the key and value projections once computed, and reuse them in subsequent steps.

[Figure: KV-Cache mechanism showing how keys and values are cached and reused across decoding steps]

With KV-caching, each decode step only computes $Q_t$, $K_t$, $V_t$ for the single new token, appends $K_t$ and $V_t$ to the cache, and computes attention using the cached keys and values:

Step 1: Compute K₁, V₁ → Cache: [K₁], [V₁]
Step 2: Compute K₂, V₂ → Cache: [K₁, K₂], [V₁, V₂]       → Only new token projected!
Step 3: Compute K₃, V₃ → Cache: [K₁, K₂, K₃], [V₁, V₂, V₃]
Step 4: Compute K₄, V₄ → Cache: [K₁..₄], [V₁..₄]

Total projection work for $n$ tokens drops from $O(n^2 \cdot d)$ to $O(n \cdot d)$ --- from quadratic to linear in sequence length.

Implementation

A production KV-cache implementation pre-allocates memory for the maximum sequence length and fills it incrementally:

import torch
import torch.nn as nn
import torch.nn.functional as F


class KVCache:
    """Pre-allocated KV-Cache for efficient autoregressive decoding.

    Pre-allocates tensors for keys and values up to the maximum sequence
    length, then fills them incrementally as tokens are generated.
    """
    def __init__(self, max_batch_size: int, max_seq_len: int,
                 n_kv_heads: int, head_dim: int,
                 dtype=torch.bfloat16, device="cuda"):
        self.max_seq_len = max_seq_len
        self.cache_k = torch.zeros(
            (max_batch_size, n_kv_heads, max_seq_len, head_dim),
            dtype=dtype, device=device,
        )
        self.cache_v = torch.zeros(
            (max_batch_size, n_kv_heads, max_seq_len, head_dim),
            dtype=dtype, device=device,
        )
        self.seq_len = 0

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """Append new keys and values to the cache.

        Args:
            k_new: New key tensor of shape (batch, n_kv_heads, new_seq_len, head_dim)
            v_new: New value tensor of shape (batch, n_kv_heads, new_seq_len, head_dim)

        Returns:
            Full cached keys and values up to the current position.
        """
        batch_size, _, new_seq_len, _ = k_new.shape
        end_pos = self.seq_len + new_seq_len

        assert end_pos <= self.max_seq_len, "Sequence length exceeds cache capacity"

        self.cache_k[:batch_size, :, self.seq_len:end_pos, :] = k_new
        self.cache_v[:batch_size, :, self.seq_len:end_pos, :] = v_new
        self.seq_len = end_pos

        return (
            self.cache_k[:batch_size, :, :self.seq_len, :],
            self.cache_v[:batch_size, :, :self.seq_len, :],
        )

    def reset(self):
        self.seq_len = 0


class CachedAttention(nn.Module):
    """Multi-head attention with KV-cache support for inference."""
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.n_rep = n_heads // n_kv_heads  # GQA repetition factor

        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache: KVCache = None):
        batch, seq_len, _ = x.shape

        q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)

        # Transpose to (batch, heads, seq_len, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Update cache and get full key/value history
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        # Expand KV heads for GQA: (batch, n_kv_heads, seq, dim) -> (batch, n_heads, seq, dim)
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Causal mask (only needed during prefill; during decode, q has length 1)
        if seq_len > 1:
            causal_mask = torch.triu(
                torch.full((seq_len, k.size(-2)), float("-inf"), device=x.device),
                diagonal=k.size(-2) - seq_len + 1,
            )
            attn_weights = attn_weights + causal_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(output)

KV-Cache Memory Analysis

The KV-cache stores two tensors (K and V) per layer, per head, for every token in the sequence. For a model with $L$ layers, $n_{kv}$ key-value heads, and head dimension $d_h$, the cache size per token per batch element is:

$$\text{KV Cache per token} = 2 \times L \times n_{kv} \times d_h \times \text{bytes}$$

For the full sequence of length $S$ with batch size $B$:

$$\text{Total KV Cache} = 2 \times L \times n_{kv} \times d_h \times S \times B \times \text{bytes}$$

Let us compute concrete numbers for LLaMA models in BF16 (2 bytes per element):

| Model | Layers | $n_{kv}$ | $d_h$ | KV per Token | 4K Context | 32K Context | 128K Context |
|---|---|---|---|---|---|---|---|
| LLaMA-2 7B | 32 | 32 | 128 | 0.5 MB | 2 GB | 16 GB | 64 GB |
| LLaMA-2 13B | 40 | 40 | 128 | 0.8 MB | 3.1 GB | 25 GB | 100 GB |
| LLaMA-2 70B | 80 | 8 (GQA) | 128 | 0.3 MB | 1.25 GB | 10 GB | 40 GB |
| LLaMA-3 405B | 126 | 8 (GQA) | 128 | 0.5 MB | 2 GB | 16 GB | 64 GB |

Notice that LLaMA-2 70B uses GQA with only 8 KV heads instead of 64, which reduces its KV-cache by 8x compared to full MHA. This is one of the primary motivations for Grouped-Query Attention (as we discussed in Part 3). Without GQA, a 70B model at 128K context would need 320 GB of KV-cache alone --- more than any single GPU can hold.

For a serving scenario with batch size 32 and 4K context, a 70B GQA model needs $0.3 \text{ MB} \times 4096 \times 32 \approx 40$ GB of KV-cache --- roughly half the capacity of an 80 GB A100. The KV-cache, not the model weights, becomes the binding memory constraint for high-throughput serving.
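
The table's numbers drop out of the formula directly; a small helper (a sketch) makes the accounting explicit:

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int = 1,
                   bytes_per_elem: int = 2) -> int:
    """2 (K and V) x layers x KV heads x head_dim x tokens x batch x bytes."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

# LLaMA-2 7B (MHA): 0.5 MB per token
per_token_7b = kv_cache_bytes(32, 32, 128, seq_len=1)       # 524288 bytes

# LLaMA-2 70B (GQA, 8 KV heads): batch 32 at 4K context
total_70b = kv_cache_bytes(80, 8, 128, seq_len=4096, batch_size=32)
# -> 40 GiB, roughly half of an 80 GB A100
```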

Quantization

The Core Idea

Quantization reduces the numerical precision of model weights (and optionally activations) from 16 or 32 bits to 8, 4, or even lower bits. Since decoding is memory-bandwidth-bound, reducing the size of the weights proportionally increases the number of tokens we can generate per second.

The basic formulation for uniform affine quantization maps a floating-point tensor $W$ to $b$-bit integers:

$$W_{\text{int}} = \text{round}\left(\frac{W - z}{s}\right), \quad s = \frac{\max(W) - \min(W)}{2^b - 1}, \quad z = \min(W)$$

where $s$ is the scale and $z$ is the zero-point. Dequantization recovers an approximation:

$$\hat{W} = s \cdot W_{\text{int}} + z$$

The quantization error is the rounding error $\|W - \hat{W}\|$, which we want to minimize while using as few bits as possible.
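
Worked numerically, the formulas look like this (a minimal sketch; the 4-bit setting and example tensor are arbitrary):

```python
import torch

def affine_quantize(w: torch.Tensor, bits: int = 4):
    """Uniform affine quantization: W_int = round((W - z) / s)."""
    z = w.min()
    s = (w.max() - z) / (2 ** bits - 1)
    w_int = torch.round((w - z) / s)
    return w_int, s, z

def affine_dequantize(w_int: torch.Tensor, s: torch.Tensor, z: torch.Tensor):
    """Dequantize: W_hat = s * W_int + z."""
    return s * w_int + z

w = torch.tensor([-1.2, -0.4, 0.0, 0.3, 0.9, 2.1])
w_int, s, z = affine_quantize(w)
w_hat = affine_dequantize(w_int, s, z)
# the rounding error is bounded by half the step size: |w - w_hat| <= s / 2
```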

Weight-Only vs. Weight-and-Activation Quantization

For LLMs, weight-only quantization is the dominant approach. Weights are static (they do not change between requests) and can be carefully quantized offline. Activations, on the other hand, vary with each input and contain outliers that make quantization harder.

| Approach | What is Quantized | Compute Kernel | Typical Formats |
|---|---|---|---|
| Weight-only | Weights (offline) | W_int x A_fp16 mixed-precision matmul | INT4, INT8, NF4 |
| Weight + Activation | Both (online) | W_int x A_int integer matmul | INT8 x INT8 |

Weight-only INT4 quantization reduces memory by 4x while keeping activations in FP16/BF16, achieving the memory savings needed to fit large models on fewer GPUs while maintaining most of the model quality.

Group Quantization

Quantizing an entire weight matrix with a single scale and zero-point is too coarse --- the outlier values force a wide range, wasting bits on the majority of values that cluster near zero. Group quantization divides each row (or column) into groups of $G$ elements (typically $G = 128$) and assigns a separate scale and zero-point to each group:

import torch

def quantize_symmetric(weight: torch.Tensor, bits: int = 4,
                       group_size: int = 128) -> tuple:
    """Symmetric group quantization.

    Divides each row into groups and quantizes each group independently
    with its own scale factor. Symmetric quantization centers around zero,
    so no zero-point is needed.

    Args:
        weight: FP16/BF16 weight tensor of shape (out_features, in_features)
        bits: Number of quantization bits
        group_size: Number of elements per quantization group

    Returns:
        Tuple of (quantized_weight, scales)
    """
    out_features, in_features = weight.shape
    assert in_features % group_size == 0

    # Reshape into groups: (out_features, n_groups, group_size)
    n_groups = in_features // group_size
    weight_grouped = weight.reshape(out_features, n_groups, group_size)

    # Compute per-group scale (symmetric: max absolute value)
    qmax = 2 ** (bits - 1) - 1  # e.g., 7 for 4-bit
    scales = weight_grouped.abs().amax(dim=-1, keepdim=True) / qmax
    scales = scales.clamp(min=1e-10)  # Avoid division by zero

    # Quantize
    weight_int = torch.round(weight_grouped / scales).clamp(-qmax, qmax).to(torch.int8)

    # Reshape back
    weight_int = weight_int.reshape(out_features, in_features)
    scales = scales.reshape(out_features, n_groups)

    return weight_int, scales


def dequantize_symmetric(weight_int: torch.Tensor, scales: torch.Tensor,
                         group_size: int = 128) -> torch.Tensor:
    """Dequantize a symmetric group-quantized weight tensor."""
    out_features, in_features = weight_int.shape
    n_groups = in_features // group_size

    weight_grouped = weight_int.reshape(out_features, n_groups, group_size).float()
    scales_expanded = scales.unsqueeze(-1)

    return (weight_grouped * scales_expanded).reshape(out_features, in_features)

The storage overhead of the scale factors is small: for group size 128 and 4-bit quantization, each group of 128 values (64 bytes in INT4) stores one FP16 scale (2 bytes), adding ~3% overhead.
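
A quick experiment shows why per-group scales matter when outliers are present (a sketch; the outlier pattern is synthetic):

```python
import torch

torch.manual_seed(0)
qmax = 7  # 4-bit symmetric

# Weights cluster near zero, with a handful of injected outliers
w = torch.randn(4096)
w[::512] *= 20

def sym_error(w: torch.Tensor, group_size: int) -> float:
    """Mean RTN error with one scale per group of `group_size` elements."""
    g = w.reshape(-1, group_size)
    s = g.abs().amax(dim=1, keepdim=True) / qmax
    w_hat = torch.round(g / s).clamp(-qmax, qmax) * s
    return (g - w_hat).abs().mean().item()

per_tensor = sym_error(w, 4096)   # one scale for the whole tensor
per_group = sym_error(w, 128)     # one scale per 128 elements
# per_group error is far smaller: each outlier only inflates its own group's scale
```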

GPTQ: Optimal Rounding via Second-Order Information

GPTQ (Frantar et al., 2022) goes beyond naive rounding by using second-order information (the Hessian of the layer-wise reconstruction error) to determine the optimal rounding direction for each weight. The key insight is that rounding a single weight up vs. down affects the entire output of the layer, and the Hessian tells us which direction minimizes that impact.

GPTQ processes weights one column at a time, and for each weight, it:

  1. Rounds the weight to the nearest quantized value.
  2. Compensates for the rounding error by adjusting the remaining (not yet quantized) weights in the same row, using the inverse Hessian to determine the optimal compensation.

This produces significantly better quantized models than naive rounding, especially at 4-bit and 3-bit precision.
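
The procedure can be sketched in a few lines (illustration only: real GPTQ uses a Cholesky factorization of the inverse Hessian, lazily batched column updates, and group-wise scales rather than one scale per row):

```python
import torch

def gptq_like_quantize(w: torch.Tensor, x: torch.Tensor,
                       bits: int = 4, damp: float = 0.01) -> torch.Tensor:
    """Quantize columns left-to-right, folding each column's rounding
    error into the not-yet-quantized columns via the inverse Hessian.

    w: (out_features, in_features); x: (in_features, n_samples) calibration
    activations for this layer.
    """
    h = x @ x.T                                        # layer-wise Hessian proxy
    h += damp * h.diag().mean() * torch.eye(h.size(0)) # dampening for stability
    h_inv = torch.linalg.inv(h)

    w = w.clone()
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True) / qmax   # one scale per row
    q = torch.zeros_like(w)

    for j in range(w.size(1)):
        col = w[:, j]
        q[:, j] = torch.round(col / scale[:, 0]).clamp(-qmax, qmax) * scale[:, 0]
        err = (col - q[:, j]) / h_inv[j, j]
        # compensate: adjust remaining columns to cancel the output error
        w[:, j + 1:] -= err.unsqueeze(1) * h_inv[j, j + 1:].unsqueeze(0)
    return q
```

Each column's rounding error is cancelled, as far as the calibration data allows, by nudging the columns that have not yet been quantized.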

AWQ: Activation-Aware Weight Quantization

AWQ (Lin et al., 2024) observes that not all weights are equally important. Weights corresponding to large activation magnitudes contribute more to the output and should be quantized more carefully. AWQ identifies "salient" weight channels by analyzing activation statistics from a small calibration dataset, then applies per-channel scaling to protect these important weights before quantization.

The scaling is chosen to minimize the quantization error for the most important channels, at the cost of slightly higher error for less important ones --- a tradeoff that consistently improves quality.
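
The core scaling trick is compact (a sketch; `alpha` and the per-channel activation magnitudes are assumed inputs, and real AWQ searches the scaling exponent per layer on calibration data):

```python
import torch

def awq_like_scale(w: torch.Tensor, act_mag: torch.Tensor, alpha: float = 0.5):
    """Scale salient input channels up before quantization.

    w: (out_features, in_features); act_mag: (in_features,) mean |activation|
    per input channel. Since y = W x = (W s)(x / s), the rescaling is exact in
    floating point and only changes which channels the quantizer preserves well.
    """
    s = act_mag.clamp(min=1e-5) ** alpha
    return w * s.unsqueeze(0), s

# at inference: y = quantize(w * s) @ (x / s)
```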

Quantization Landscape Summary

| Method | Bits | Calibration Data | Key Idea | Typical Quality (vs FP16) |
|---|---|---|---|---|
| Round-to-Nearest (RTN) | 8 / 4 | None | Naive rounding | Good at INT8, poor at INT4 |
| GPTQ | 4 / 3 | ~128 samples | Hessian-based optimal rounding | ≤1% degradation at INT4 |
| AWQ | 4 | ~128 samples | Activation-aware channel scaling | ≤1% degradation at INT4 |
| SqueezeLLM | 4 / 3 | ~128 samples | Non-uniform quantization | Competitive with GPTQ |
| NF4 (QLoRA) | 4 | None | Normal-float quantization | Designed for fine-tuning |
| QuIP# | 2 / 4 | ~128 samples | Incoherence processing + lattice codes | Best quality at 2-bit |
In practice, 4-bit loading is a few lines with Hugging Face Transformers and bitsandbytes:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# INT4 quantization with bitsandbytes (NF4 format)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",           # Normal Float 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in BF16
    bnb_4bit_use_double_quant=True,      # Quantize the quantization constants too
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=quantization_config,
    device_map="auto",       # Automatically shard across available GPUs
    torch_dtype=torch.bfloat16,
)
# 70B model in ~35 GB instead of ~140 GB

Speculative Decoding

The Insight

As we established, autoregressive decoding is memory-bound: the GPU loads all model weights to generate a single token. The compute units are vastly underutilized. Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) exploits this idle compute by generating multiple candidate tokens cheaply, then verifying them in parallel with the large model.

[Figure: Speculative decoding workflow showing draft model generating candidates and target model verifying in parallel]

The key property that makes speculative decoding work is that verification is nearly free in autoregressive models. If we have $K$ candidate tokens, we can compute the target model's probability for all $K$ tokens in a single forward pass (the same forward pass we would need for the next token anyway, just with $K$ extra tokens in the input). The marginal cost of verification is negligible compared to a fresh forward pass for each token.

The Algorithm in Detail

Speculative decoding proceeds in rounds. Each round:

  1. Draft phase: A small, fast draft model (e.g., a 1B model when the target is 70B) generates $K$ candidate tokens autoregressively. This is fast because the draft model is small.

  2. Verification phase: The target model processes the entire candidate sequence in a single forward pass, producing probability distributions $p(x_t \mid x_{<t})$ for each position.

  3. Accept/reject: For each candidate token, we compare the draft model's probability $q(x_t)$ with the target model's probability $p(x_t)$. We accept the token with probability $\min(1, p(x_t) / q(x_t))$. This rejection sampling scheme guarantees that the final output distribution is identical to the target model's distribution --- speculative decoding is lossless.

  4. Correction: At the first rejected position, we sample from an adjusted distribution $\text{normalize}(\max(0, p(x) - q(x)))$ to correct for the draft model's bias.

import torch
import torch.nn.functional as F


def speculative_decode(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    draft_length: int = 5,
    temperature: float = 1.0,
):
    """Speculative decoding with rejection sampling.

    Generates tokens that are distributed identically to sampling from the
    target model, but potentially much faster by amortizing the target
    model's forward pass over multiple tokens.

    Args:
        draft_model: Small, fast model for generating candidates.
        target_model: Large model whose distribution we want to sample from.
        input_ids: Input token IDs, shape (1, seq_len).
        max_new_tokens: Maximum tokens to generate.
        draft_length: Number of speculative tokens per round (K).
        temperature: Sampling temperature.

    Returns:
        Generated token IDs.
    """
    generated = input_ids.clone()
    tokens_generated = 0

    while tokens_generated < max_new_tokens:
        # --- Draft phase ---
        # Generate K candidate tokens with the small model
        draft_tokens = []
        draft_probs = []
        draft_input = generated.clone()

        for _ in range(draft_length):
            with torch.no_grad():
                logits = draft_model(draft_input).logits[:, -1, :]
                probs = F.softmax(logits / temperature, dim=-1)
                token = torch.multinomial(probs, num_samples=1)
                draft_tokens.append(token)
                draft_probs.append(probs)
                draft_input = torch.cat([draft_input, token], dim=-1)

        draft_tokens = torch.cat(draft_tokens, dim=-1)  # (1, K)

        # --- Verification phase ---
        # Single forward pass through target model for all K+1 positions
        candidate_seq = torch.cat([generated, draft_tokens], dim=-1)
        with torch.no_grad():
            target_logits = target_model(candidate_seq).logits

        # Extract target probabilities at the K draft positions
        # Position indices: last K+1 positions of the output
        start_pos = generated.size(1) - 1
        target_probs_all = F.softmax(
            target_logits[:, start_pos:start_pos + draft_length + 1, :] / temperature,
            dim=-1,
        )

        # --- Accept/Reject ---
        n_accepted = 0
        for i in range(draft_length):
            token_id = draft_tokens[0, i].item()

            p_target = target_probs_all[0, i, token_id].item()
            p_draft = draft_probs[i][0, token_id].item()

            # Rejection sampling: accept with probability min(1, p/q)
            acceptance_prob = min(1.0, p_target / max(p_draft, 1e-10))

            if torch.rand(1).item() < acceptance_prob:
                n_accepted += 1
            else:
                # Reject: sample from the adjusted distribution
                adjusted = torch.clamp(
                    target_probs_all[0, i, :] - draft_probs[i][0, :],
                    min=0.0,
                )
                adjusted = adjusted / adjusted.sum()
                correction_token = torch.multinomial(adjusted, num_samples=1)

                generated = torch.cat([
                    generated, draft_tokens[:, :i], correction_token.unsqueeze(0)
                ], dim=-1)
                tokens_generated += i + 1
                break
        else:
            # All K tokens accepted! Sample one more from the target at position K+1
            bonus_token = torch.multinomial(
                target_probs_all[0, draft_length, :].unsqueeze(0), num_samples=1
            )
            generated = torch.cat([generated, draft_tokens, bonus_token], dim=-1)
            tokens_generated += draft_length + 1

    return generated


Expected Speedup

If each draft token has acceptance probability $\alpha$ (determined by how well the draft model approximates the target), the expected number of tokens generated per verification round with $K$ draft tokens is:

$$\mathbb{E}[\text{tokens per round}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$

The wallclock speedup depends on the relative cost of the draft and target models. If the draft model runs in time $c \cdot T_{\text{target}}$ where $c \ll 1$, the speedup is approximately:

$$\text{Speedup} \approx \frac{1 - \alpha^{K+1}}{(1 - \alpha)(1 + K \cdot c)}$$
| Acceptance Rate ($\alpha$) | K=4 Tokens/Round | K=8 Tokens/Round | Approx. Speedup (K=5, c=0.05) |
|---|---|---|---|
| 0.5 | 1.94 | 2.00 | 1.5x |
| 0.7 | 2.53 | 2.79 | 2.0x |
| 0.8 | 3.00 | 3.57 | 2.4x |
| 0.9 | 3.69 | 5.70 | 2.9x |
| 0.95 | 4.24 | 7.30 | 3.4x |

In practice, speculative decoding achieves 2-3x speedup for code generation and translation tasks (where the draft model is a good predictor), and 1.3-1.8x for more creative/open-ended generation.
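
The formulas above translate into a small planning helper (a sketch; $\alpha$ and $c$ must be measured for a concrete draft/target pair):

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """E[tokens/round] = (1 - alpha^(K+1)) / (1 - alpha)."""
    if alpha >= 1.0:
        return k + 1.0
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def speculative_speedup(alpha: float, k: int, c: float) -> float:
    """Wallclock speedup when each draft step costs c * (one target step)."""
    return expected_tokens_per_round(alpha, k) / (1 + k * c)

# a useless draft (alpha = 0) yields exactly 1 token per round, and the K
# wasted draft steps push the "speedup" below 1 -- drafting can hurt
```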

Draft Model Selection

The choice of draft model significantly impacts the acceptance rate and overall speedup:

| Strategy | Draft Model | Typical $\alpha$ | Notes |
|---|---|---|---|
| Smaller version | LLaMA-1B for LLaMA-70B | 0.7-0.85 | Most common approach |
| Quantized target | INT4 version of target | 0.8-0.9 | High acceptance, but still expensive |
| N-gram / lookup | Token frequency table | 0.3-0.5 | Nearly free, low acceptance |
| Medusa heads | Extra prediction heads on target | 0.6-0.8 | No separate model needed |
| EAGLE | Feature-level draft | 0.7-0.85 | Predicts hidden states, not tokens |

Continuous Batching

The Static Batching Problem

In traditional (static) batching, a batch of requests is assembled, processed until the longest request finishes, and then the batch is released. Short requests that finish early waste GPU cycles while padding to match the longest request:

Static Batch:
Request A: [============]
Request B: [====]            ← GPU idle for 67% of the time
Request C: [========]        ← GPU idle for 33% of the time
           ↑ All must wait for A to finish before new requests can begin

If requests have variable output lengths (which they always do), static batching wastes 30-70% of GPU capacity on padding.

Continuous Batching (Iteration-Level Scheduling)

Continuous batching (Yu et al., 2022) operates at the granularity of individual decode steps rather than complete requests. After each decode iteration, finished requests are evicted and new requests are inserted:

Continuous Batching:
Time  →   T0   T1   T2   T3   T4   T5   T6   T7   T8
Slot 0:   A    A    A    A    D    D    F    F    F
Slot 1:   B    B    C    C    C    E    E    G    G
Slot 2:   C    C    C    D    D    D    D    G    G

Requests A,B finish → slots reused by D,E immediately
No wasted GPU cycles on padding!

Continuous batching can increase throughput by 2-5x compared to static batching, depending on the variance in output lengths.
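
A toy simulation makes the gap concrete (a sketch; the request lengths and slot count are made up):

```python
from collections import deque

def static_steps(lengths: list[int], n_slots: int) -> int:
    """Static batching: each batch runs until its longest request finishes."""
    steps = 0
    for i in range(0, len(lengths), n_slots):
        steps += max(lengths[i:i + n_slots])
    return steps

def continuous_steps(lengths: list[int], n_slots: int) -> int:
    """Iteration-level scheduling: refill a slot as soon as its request ends."""
    queue = deque(lengths)
    slots = [queue.popleft() for _ in range(min(n_slots, len(queue)))]
    steps = 0
    while slots:
        steps += 1
        slots = [r - 1 for r in slots if r > 1]   # finished requests leave
        while queue and len(slots) < n_slots:     # waiting requests join
            slots.append(queue.popleft())
    return steps
```

For output lengths [12, 4, 8, 6, 5, 7] and 3 slots, static batching takes 19 decode steps while continuous batching takes 17; the gap widens as the variance in output lengths grows.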

PagedAttention and vLLM

The memory management challenge in continuous batching is the KV-cache. Each active request has a KV-cache that grows with each generated token. Requests start and end at different times, creating fragmentation in GPU memory --- analogous to the memory fragmentation problem in operating systems.

PagedAttention (Kwon et al., 2023), introduced in vLLM, solves this by borrowing the concept of virtual memory paging. Instead of allocating a single contiguous block for each request's KV-cache, PagedAttention divides the cache into fixed-size blocks (e.g., 16 tokens per block) and maps them through a block table:

Request A's KV-Cache (logical): [Block 0][Block 1][Block 2][Block 3]
                                    ↓        ↓        ↓        ↓
Physical GPU memory pages:      [Page 7 ][Page 2 ][Page 13][Page 4 ]
                                    ↑                 ↑
Request B's KV-Cache (logical): [Block 0][Block 1][Block 2]
                                    ↓        ↓        ↓
Physical GPU memory pages:      [Page 1 ][Page 9 ][Page 11]
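
A minimal block-table allocator captures the mechanism (a sketch; real vLLM adds reference counting and copy-on-write for shared prefixes, plus GPU-side attention kernels that read the block table):

```python
class PagedKVAllocator:
    """Map each request's logical KV blocks to arbitrary physical pages."""

    def __init__(self, n_pages: int, page_size: int = 16):
        self.page_size = page_size
        self.free_pages = list(range(n_pages))
        self.block_tables: dict[str, list[int]] = {}   # request -> physical pages
        self.lengths: dict[str, int] = {}              # request -> tokens stored

    def append_token(self, req: str) -> int:
        """Reserve cache space for one new token; returns its physical page."""
        table = self.block_tables.setdefault(req, [])
        n = self.lengths.get(req, 0)
        if n % self.page_size == 0:                    # current page full (or none yet)
            if not self.free_pages:
                raise MemoryError("KV-cache pages exhausted; preempt a request")
            table.append(self.free_pages.pop())
        self.lengths[req] = n + 1
        return table[-1]

    def free(self, req: str):
        """Request finished: return all its pages to the free pool."""
        self.free_pages.extend(self.block_tables.pop(req, []))
        self.lengths.pop(req, None)
```

Internal fragmentation is bounded by one partially filled page per request, and freed pages are immediately reusable by any other request.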

This provides several benefits:

  • No internal fragmentation: Memory is allocated in small fixed-size pages, not large contiguous blocks. Waste is at most one page per request.
  • No external fragmentation: Pages can be allocated from anywhere in GPU memory, unlike contiguous allocation which suffers from fragmentation.
  • Memory sharing: Requests that share a common prefix (e.g., system prompt) can share KV-cache pages via copy-on-write, dramatically reducing memory for batched serving with shared prompts.
  • Near-optimal utilization: vLLM reports KV-cache memory utilization above 96%, compared to 20-50% for static allocation.
A minimal vLLM setup looks like this:

from vllm import LLM, SamplingParams

# Initialize vLLM with PagedAttention
llm = LLM(
    model="meta-llama/Llama-2-70b-chat-hf",
    tensor_parallel_size=4,          # Spread across 4 GPUs
    max_num_seqs=256,                # Max concurrent sequences
    max_num_batched_tokens=8192,     # Max tokens per iteration
    gpu_memory_utilization=0.90,     # Use 90% of GPU memory for KV-cache
    quantization="awq",             # Optional: combine with quantization
    dtype="bfloat16",
)

# Serve multiple requests efficiently
prompts = [
    "Explain the theory of relativity in simple terms.",
    "Write a Python function to sort a linked list.",
    "What is the capital of France?",
    # ... hundreds more requests
]

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=512,
)

# vLLM handles continuous batching, KV-cache management,
# and PagedAttention automatically
outputs = llm.generate(prompts, sampling_params)

Production Serving Stack

Optimization Pipeline

A production deployment typically stacks multiple optimizations:

Trained Model (FP16/BF16)
    ↓
Quantization (AWQ/GPTQ to INT4)         → 4x memory reduction
    ↓
Tensor Parallelism (across GPUs)         → Distribute model across nodes
    ↓
Continuous Batching + PagedAttention     → Maximize throughput
    ↓
Speculative Decoding (optional)          → Reduce per-request latency
    ↓
Production Serving Engine (vLLM, TRT-LLM)

Inference Engine Comparison

| Engine | PagedAttention | Speculative Decoding | Quantization | Multi-GPU | Best For |
|---|---|---|---|---|---|
| vLLM | Yes (native) | Yes | AWQ, GPTQ, FP8 | TP, PP | High-throughput serving |
| TensorRT-LLM | Yes | Yes | INT4, INT8, FP8 | TP, PP | Lowest latency on NVIDIA |
| SGLang | Yes (RadixAttention) | Yes | AWQ, GPTQ | TP | Structured generation |
| llama.cpp | Partial | No | GGUF (2-8 bit) | Limited | CPU/edge deployment |
| TGI | Yes | No | AWQ, GPTQ, bitsandbytes | TP | HuggingFace ecosystem |
| Ollama | Via llama.cpp | No | GGUF | Limited | Local development |

Latency Breakdown

For a typical request with 512 input tokens and 256 output tokens on a LLaMA-70B (INT4, single A100):

| Phase | Time | Percentage | Bottleneck |
|---|---|---|---|
| Prefill (process 512 input tokens) | ~150ms | 12% | Compute-bound |
| Decode (generate 256 output tokens) | ~1000ms (~4ms/token) | 80% | Memory-bandwidth-bound |
| Sampling + post-processing | ~50ms | 4% | CPU |
| Network + scheduling overhead | ~50ms | 4% | I/O |
| Total | ~1250ms | 100% | |

Cost Optimization: Choosing the Right Configuration

The optimal serving configuration depends on whether you are optimizing for latency (time per request) or throughput (total tokens per dollar):

| Priority | Strategy | Typical Config |
|---|---|---|
| Lowest latency | Tensor parallelism across many GPUs, speculative decoding, FP8 | 4-8 GPUs per model, small batches |
| Highest throughput | Maximum batch size, INT4 quantization, continuous batching | 1-2 GPUs per model, large batches |
| Lowest cost | Aggressive quantization (INT4), maximize batch size, spot instances | Minimum GPUs, maximum utilization |

Summary

| Technique | What It Does | Speedup / Savings | Tradeoff |
|---|---|---|---|
| KV-Cache | Caches key/value projections across decode steps | $O(n^2) \to O(n)$ compute | Memory proportional to sequence length |
| GQA | Reduces KV heads (see Part 3) | 4-8x KV-cache reduction | Marginal quality impact |
| INT4 Quantization | Reduces weight precision to 4 bits | 4x memory, ~2x throughput | ≤1% quality loss with GPTQ/AWQ |
| Speculative Decoding | Generates multiple tokens per target model pass | 2-3x decode speed | Requires draft model, variable speedup |
| Continuous Batching | Inserts/removes requests at each decode step | 2-5x throughput | Implementation complexity |
| PagedAttention | Paged memory management for KV-cache | >96% memory utilization | Custom CUDA kernels |

Each technique targets a different bottleneck, and they compose multiplicatively. A production stack combining INT4 quantization, continuous batching with PagedAttention, and speculative decoding can serve a 70B model at throughputs that were unimaginable just two years ago --- handling thousands of concurrent users from a single 8-GPU node.


In the next post, we will explore Part 7: Minor But Important Changes --- the seemingly small architectural decisions that collectively make a large difference: removing bias terms, tied vs. untied embeddings, parallel attention and FFN blocks, initialization schemes, and other design patterns found in modern LLMs.

References

  • Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. ICML 2023. arXiv:2211.17192.
  • Chen, C., Borgeaud, S., Irving, G., Lespiau, J.-B., Sifre, L., & Jumper, J. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. arXiv:2302.01318.
  • Frantar, E., Ashkboos, S., Hoefler, T., & Alistarh, D. (2022). GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers. ICLR 2023. arXiv:2210.17323.
  • Lin, J., Tang, J., Tang, H., Yang, S., Chen, W.-M., Wang, W.-C., ... & Han, S. (2024). AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration. MLSys 2024. arXiv:2306.00978.
  • Kwon, W., Li, Z., Zhuang, S., Sheng, Y., Zheng, L., Yu, C. H., ... & Stoica, I. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. SOSP 2023. arXiv:2309.06180.
  • Yu, G.-I., Jeong, J. S., Kim, G.-W., Kim, S., & Chun, B.-G. (2022). Orca: A Distributed Serving System for Transformer-Based Generative Models. OSDI 2022.
  • Dettmers, T., Pagnoni, A., Holtzman, A., & Zettlemoyer, L. (2023). QLoRA: Efficient Finetuning of Quantized Language Models. NeurIPS 2023. arXiv:2305.14314.
  • Sheng, Y., Zheng, L., Yuan, B., Li, Z., Ryabinin, M., Chen, B., ... & Stoica, I. (2023). FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU. ICML 2023. arXiv:2303.06865.
  • Pope, R., Douglas, S., Chowdhery, A., et al. (2023). Efficiently Scaling Transformer Inference. MLSys 2023. arXiv:2211.05102.
  • Cai, T., Li, Y., Geng, Z., Peng, H., Lee, J. D., Chen, D., & Dao, T. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads. ICML 2024. arXiv:2401.10774.
  • Li, Y., Cai, T., Zhang, Y., Chen, D., & Dao, T. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. ICML 2024. arXiv:2401.15077.

Written by Suchinthaka Wanninayaka

AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.
