transformers · attention · rope · flash-attention · gqa · deep-learning

Transformer Deep Dive: Part 3 - Attention Modifications

Evolution of the attention mechanism - from sinusoidal to RoPE positional encoding, Multi-Query Attention, Grouped Query Attention, and the revolutionary FlashAttention algorithm.


Suchinthaka W.

January 17, 2025 · 6 min read

The attention mechanism is the heart of the Transformer. Since 2017, researchers have developed numerous improvements to make attention more efficient, extend context lengths, and reduce memory requirements.

The Challenges

  1. Quadratic Complexity: The $QK^\top$ computation is $O(n^2)$ in sequence length
  2. Memory Bandwidth: Attention is memory-bound, not compute-bound
  3. Position Information: The original architecture lacks inherent position awareness
  4. KV-Cache Size: During inference, storing keys and values becomes expensive

Positional Encoding Evolution

Timeline

2017: Sinusoidal (Vaswani)
2018: Learned Embeddings (BERT, GPT)
2021: RoPE (Su et al.)
2022: ALiBi (Press et al.)
2023+: Extended RoPE variants

Original: Sinusoidal Encoding

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$$

Properties:

  • Fixed (not learned)
  • Added to token embeddings
  • Can theoretically extrapolate to longer sequences
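
To make the formulas concrete, here is a minimal PyTorch sketch of the sinusoidal table (the function name and shapes are mine, not from the original paper; it assumes an even d_model):

import torch

def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    # One row per position, one column per embedding dimension
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    i = torch.arange(0, d_model, 2, dtype=torch.float32)            # even dimension indices
    angles = pos / (10000 ** (i / d_model))                         # (max_len, d_model / 2)

    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angles)   # even dimensions get sine
    pe[:, 1::2] = torch.cos(angles)   # odd dimensions get cosine
    return pe                         # added to token embeddings before the first layer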

Learned Positional Embeddings

BERT and GPT-1/2 used learned position embeddings - simply a lookup table of learnable vectors.

Limitations:

  • Fixed maximum sequence length
  • No extrapolation beyond training length

Rotary Position Embedding (RoPE)

RoPE encodes position information directly into the attention mechanism by rotating query and key vectors.

Core Idea: Instead of adding position information to embeddings, apply a rotation matrix based on position:

$$f_q(x_m, m) = R_{\Theta,m} \cdot (W_q x_m)$$

where $R_{\Theta,m}$ is a rotation matrix that depends on the position $m$.

Key Properties:

  • Relative position information is encoded in the dot product $q_m^\top k_n$
  • The attention score depends only on the relative distance $(m - n)$
  • Naturally decays with distance

Implementation:

import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # View the last dimension as pairs and reinterpret each pair as a complex number
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Apply the position-dependent rotation via complex multiplication
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)
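
The freqs_cis tensor above holds the precomputed complex rotations $e^{i m \theta_j}$. A sketch of how it might be precomputed, in the style of the LLaMA reference code (the default theta and the broadcasting reshape are assumptions on my part):

def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-pair rotation frequencies theta_j = theta^(-2j / head_dim)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim / 2,)
    t = torch.arange(seq_len, dtype=torch.float32)                              # positions m
    freqs = torch.outer(t, freqs)                                               # (seq_len, head_dim / 2)
    # Unit-magnitude complex numbers: cos(m * theta_j) + i * sin(m * theta_j)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # Reshape to (1, seq_len, 1, head_dim / 2) so it broadcasts against
    # (batch, seq_len, n_heads, head_dim / 2) inside apply_rotary_emb
    return freqs_cis.view(1, seq_len, 1, -1)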

Context Length Extension

RoPE enables various techniques for extending context length beyond training:

| Technique | Method | Used By |
|-----------|--------|---------|
| Position Interpolation | Scale positions | LLaMA 2 |
| NTK-Aware Scaling | Modify base frequency | Various |
| YaRN | Dynamic scaling | Mistral, LLaMA variants |
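
As a rough illustration of Position Interpolation (the simplest of the three): positions beyond the training length are scaled down so they map back into the trained range. This is only a conceptual sketch; NTK-aware scaling and YaRN instead adjust the RoPE base frequency and apply per-dimension scaling.

import torch

def interpolated_positions(seq_len: int, train_len: int) -> torch.Tensor:
    # E.g. seq_len=8192 with train_len=4096 -> every position is halved,
    # so RoPE only ever sees positions in the range it was trained on.
    scale = min(1.0, train_len / seq_len)
    return torch.arange(seq_len, dtype=torch.float32) * scale  # fractional positions fed to RoPE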

Multi-Query Attention (MQA)

The KV-Cache Problem

During inference, we cache keys and values to avoid recomputation. For long sequences, this becomes expensive:

$$\text{KV Cache Size} = 2 \times L \times n_{heads} \times d_{head} \times \text{seq\_len}$$

For a 70B-class model serving a 32K context, the cache can exceed 100 GB once precision and batch size are accounted for - a back-of-the-envelope calculation follows.
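
The numbers below are illustrative 70B-class values (80 layers, 64 heads of dimension 128, FP16), not any specific model's exact configuration:

# Factor of 2 accounts for storing both keys and values
layers, n_heads, d_head = 80, 64, 128
seq_len, bytes_per_elem, batch = 32_768, 2, 2   # 32K context, FP16, batch of 2
kv_bytes = 2 * layers * n_heads * d_head * seq_len * bytes_per_elem * batch
print(f"{kv_bytes / 1e9:.0f} GB")   # ~172 GB (about 86 GB per sequence)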

MQA Solution

Idea: Share a single key and value head across all query heads.

Standard attention:

  • Q: [batch, seq, n_heads, d_head]
  • K: [batch, seq, n_heads, d_head]
  • V: [batch, seq, n_heads, d_head]

Multi-Query Attention:

  • Q: [batch, seq, n_heads, d_head]
  • K: [batch, seq, 1, d_head]
  • V: [batch, seq, 1, d_head]

Benefits:

  • KV cache reduced by factor of n_heads (e.g., 32×)
  • Inference much faster
  • Minimal quality loss
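
A minimal sketch of the shape trick (tensor sizes are illustrative): the single K/V head is broadcast across all query heads as a view, so no per-head copies need to be stored in the cache.

import torch

B, T, n_heads, d_head = 2, 16, 32, 128
q = torch.randn(B, T, n_heads, d_head)
k = torch.randn(B, T, 1, d_head)        # single shared key head (this is what gets cached)
v = torch.randn(B, T, 1, d_head)        # single shared value head

# Broadcast views for the attention computation - no extra memory allocated
k = k.expand(B, T, n_heads, d_head)
v = v.expand(B, T, n_heads, d_head)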

Grouped Query Attention (GQA)

GQA is a middle ground between MHA and MQA.

| Method | KV Heads | Query Heads | KV Cache Reduction |
|--------|----------|-------------|--------------------|
| MHA | H | H | 1× (baseline) |
| GQA | G | H | H/G× |
| MQA | 1 | H | H× |

Example: LLaMA 2 70B uses 8 KV heads for 64 query heads (8× reduction).

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads
        self.head_dim = d_model // n_heads

        self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.shape

        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Repeat KV heads to match query heads
        k = k.repeat_interleave(self.n_groups, dim=2)
        v = v.repeat_interleave(self.n_groups, dim=2)

        # Standard causal attention over (B, n_heads, T, head_dim)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Merge heads and project back to d_model
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)
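
Usage with LLaMA 2 70B-like shapes (purely illustrative tensor sizes):

gqa = GroupedQueryAttention(d_model=8192, n_heads=64, n_kv_heads=8)
x = torch.randn(2, 16, 8192)      # (batch, seq, d_model)
out = gqa(x)                      # (2, 16, 8192)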

FlashAttention

The Memory Bandwidth Problem

Standard attention is memory-bound, not compute-bound. Modern GPUs have:

| GPU | Memory BW (TB/s) | FP16 TFLOPS | Ratio (FLOPs : bytes) |
|-----|------------------|-------------|-----------------------|
| A100 | 2.0 | 312 | 156:1 |
| H100 | 3.35 | 990 | 296:1 |

For every byte loaded from memory, we can do 150-300 operations. But standard attention reads/writes the full N×N attention matrix.

FlashAttention Algorithm

Key Insight: Never materialize the full N×N attention matrix. Instead, compute attention in blocks (tiles) that fit in fast SRAM.

Tiling Strategy:

  1. Divide Q, K, V into blocks
  2. For each block, compute partial softmax in SRAM
  3. Use online softmax to combine results
  4. Never write N×N matrix to slow HBM
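
The online softmax in step 3 is what makes the tiling work: running max and sum statistics let results from earlier blocks be rescaled as new blocks arrive. A naive PyTorch sketch of the idea (single head, no causal mask; the real kernel does this per tile in SRAM):

import torch

def online_softmax_attention(q, k_blocks, v_blocks):
    # Computes softmax(q @ k.T / sqrt(d)) @ v one key/value block at a time,
    # without ever materializing the full attention matrix.
    # q: (seq_q, d); k_blocks / v_blocks: lists of (block_size, d) tiles.
    d = q.shape[-1]
    m = torch.full(q.shape[:-1], float("-inf"), dtype=q.dtype, device=q.device)  # running row max
    l = torch.zeros(q.shape[:-1], dtype=q.dtype, device=q.device)                # running denominator
    acc = torch.zeros_like(q)                                                    # running weighted sum

    for k_blk, v_blk in zip(k_blocks, v_blocks):
        s = q @ k_blk.transpose(-2, -1) / d ** 0.5      # scores for this block
        m_new = torch.maximum(m, s.max(dim=-1).values)
        scale = torch.exp(m - m_new)                    # rescale previously accumulated statistics
        p = torch.exp(s - m_new.unsqueeze(-1))          # unnormalized block probabilities
        l = l * scale + p.sum(dim=-1)
        acc = acc * scale.unsqueeze(-1) + p @ v_blk
        m = m_new

    return acc / l.unsqueeze(-1)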

Memory Hierarchy

┌───────────────────────────────────────────────┐
│ GPU Compute Units (CUDA Cores, Tensor Cores)  │
└───────────────────────────────────────────────┘
                  ↑↓ Very fast
┌───────────────────────────────────────────────┐
│ SRAM (on-chip): ~20 MB, ~19 TB/s bandwidth    │
└───────────────────────────────────────────────┘
             ↑↓ Slower (bottleneck!)
┌───────────────────────────────────────────────┐
│ HBM (off-chip): 40-80 GB, ~2-3 TB/s bandwidth │
└───────────────────────────────────────────────┘

FlashAttention-2 Improvements

  1. Better parallelism: Parallelize over sequence length dimension
  2. Reduced non-matmul FLOPs: Minimize warp shuffles and communication
  3. Work partitioning: Better distribution across thread blocks

Performance Gains

| Method | Memory | Speed |
|--------|--------|-------|
| Standard Attention | O(N²) | Baseline |
| FlashAttention | O(N) | 2-4× faster |
| FlashAttention-2 | O(N) | ~2× faster than FlashAttention |

Usage

FlashAttention is now integrated into PyTorch:

# PyTorch 2.0+
import torch
from torch.nn.functional import scaled_dot_product_attention

# (batch, n_heads, seq_len, head_dim); FP16/BF16 on GPU is needed for the FlashAttention kernel
q = k = v = torch.randn(1, 8, 1024, 64)

# Automatically uses FlashAttention when available
output = scaled_dot_product_attention(q, k, v, is_causal=True)
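
To verify that the FlashAttention kernel is actually being selected, recent PyTorch versions expose a context manager that restricts dispatch to a specific backend; the exact module path has moved between releases, so check the docs for your version:

# Errors out if the FlashAttention kernel can't handle these inputs
# (e.g. unsupported dtype or head_dim), instead of silently falling back
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = scaled_dot_product_attention(q, k, v, is_causal=True)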

Sparse Attention Patterns

For very long sequences, even O(N) memory can be prohibitive. Sparse attention patterns reduce complexity further:

Types of Sparse Attention

  1. Local/Sliding Window: Each token attends to fixed window
  2. Strided: Attend to every k-th token
  3. Block Sparse: Attend to specific blocks
  4. Longformer: Local + global attention tokens

Sliding Window Attention

Used by Mistral:

Position 0:   Can see [0]
Position 1:   Can see [0, 1]
Position 2:   Can see [0, 1, 2]
...
Position W-1: Can see [0, 1, ..., W-1]
Position W:   Can see [1, 2, ..., W]      ← Slides!

With window size W, complexity is O(N × W) instead of O(N²).
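
A simple way to prototype this (not Mistral's actual implementation, which uses a rolling buffer cache and fused kernels) is a boolean mask built once per sequence length:

import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where attention is allowed: token i sees tokens max(0, i - window + 1) .. i
    i = torch.arange(seq_len).unsqueeze(1)   # query positions (column)
    j = torch.arange(seq_len).unsqueeze(0)   # key positions (row)
    return (j <= i) & (j > i - window)

# Can be passed as attn_mask to scaled_dot_product_attention (with is_causal=False)
mask = sliding_window_mask(seq_len=8, window=3)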

Summary

| Innovation | Problem Solved | Used By |
|------------|----------------|---------|
| RoPE | Position encoding that extrapolates | LLaMA, Mistral, GPT-NeoX |
| GQA | KV cache size | LLaMA 2, Mistral |
| FlashAttention | Memory bandwidth | Nearly all modern LLMs |
| Sliding Window | Very long contexts | Mistral, Longformer |


In the next post, we'll explore Part 4: FFN Modifications - activation functions (GELU, SwiGLU), gated FFNs, and Mixture of Experts.
