Transformer Deep Dive: Part 3 - Attention Modifications
Evolution of the attention mechanism - from sinusoidal to RoPE positional encoding, Multi-Query Attention, Grouped Query Attention, and the revolutionary FlashAttention algorithm.
Suchinthaka W.
January 17, 2025 · 6 min read
The attention mechanism is the heart of the Transformer. Since 2017, researchers have developed numerous improvements to make attention more efficient, extend context lengths, and reduce memory requirements.
The Challenges
- Quadratic Complexity: Attention computation scales as O(N²) in sequence length N
- Memory Bandwidth: Attention is memory-bound, not compute-bound
- Position Information: The original architecture lacks inherent position awareness
- KV-Cache Size: During inference, storing keys and values becomes expensive
Positional Encoding Evolution
Timeline
2017: Sinusoidal (Vaswani)
2018: Learned Embeddings (BERT, GPT)
2021: RoPE (Su et al.)
2022: ALiBi (Press et al.)
2023+: Extended RoPE variants
Original: Sinusoidal Encoding
Properties:
- Fixed (not learned)
- Added to token embeddings
- Can theoretically extrapolate to longer sequences
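For reference, a minimal sketch of the original formulation, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), as in Vaswani et al.; the function name is just for illustration and an even d_model is assumed:

import torch

def sinusoidal_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    # PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    # PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    pos = torch.arange(seq_len).unsqueeze(1).float()                    # [seq, 1]
    inv_freq = 10000 ** (-torch.arange(0, d_model, 2).float() / d_model)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * inv_freq)
    pe[:, 1::2] = torch.cos(pos * inv_freq)
    return pe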
Learned Positional Embeddings
BERT and GPT-1/2 used learned position embeddings - simply a lookup table of learnable vectors.
Limitations:
- Fixed maximum sequence length
- No extrapolation beyond training length
Rotary Position Embedding (RoPE)
RoPE encodes position information directly into the attention mechanism by rotating query and key vectors.
Core Idea: Instead of adding position information to embeddings, apply a rotation matrix based on position:
q'_m = R_m q_m,    k'_n = R_n k_n

where R_m is a rotation matrix whose angles depend on the position m.
Key Properties:
- Relative position information is encoded in the dot product: (R_m q)^T (R_n k) = q^T R_(n-m) k
- The attention score depends only on n - m, the relative distance between positions
- Naturally decays with distance
Implementation:
import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # xq, xk: [batch, seq, n_heads, head_dim]
    # freqs_cis: precomputed complex exponentials e^(i*m*theta_j), broadcastable
    # against [batch, seq, n_heads, head_dim // 2]
    # View consecutive channel pairs as complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Apply the position-dependent rotation via complex multiplication
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
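The freqs_cis argument is a tensor of precomputed complex exponentials. A minimal sketch of how it might be built, mirroring the LLaMA reference implementation's precompute_freqs_cis (the base of 10000 and the broadcast shape follow that convention, but treat the details as an assumption):

import torch

def precompute_freqs_cis(head_dim: int, seq_len: int, base: float = 10000.0) -> torch.Tensor:
    # theta_j = base^(-2j / head_dim) for each channel pair j
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)   # [seq, head_dim // 2]
    # e^(i * m * theta_j) as complex numbers, reshaped to broadcast against
    # [batch, seq, n_heads, head_dim // 2]
    return torch.polar(torch.ones_like(angles), angles).view(1, seq_len, 1, -1)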
Context Length Extension
RoPE enables various techniques for extending context length beyond training:
| Technique | Method | Used By |
|-----------|--------|---------|
| Position Interpolation | Scale positions into the trained range | LLaMA 2 |
| NTK-Aware Scaling | Modify the RoPE base frequency | Various |
| YaRN | Dynamic scaling | Mistral, LLaMA variants |
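As a rough sketch of the first two rows, both amount to changing how the rotation angles are computed: Position Interpolation compresses positions into the trained range, while NTK-aware scaling enlarges the RoPE base. This builds on the hypothetical precompute_freqs_cis helper above, and the NTK formula shown is one common formulation, not the only one:

import torch

# Position Interpolation: squeeze longer sequences back into the trained position range.
def interpolated_freqs_cis(head_dim: int, seq_len: int, train_len: int) -> torch.Tensor:
    scale = train_len / seq_len                       # < 1 when extending context
    freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float() * scale, freqs)
    return torch.polar(torch.ones_like(angles), angles).view(1, seq_len, 1, -1)

# NTK-aware scaling: keep positions as-is but enlarge the base frequency.
def ntk_scaled_freqs_cis(head_dim: int, seq_len: int, train_len: int) -> torch.Tensor:
    factor = seq_len / train_len
    new_base = 10000.0 * factor ** (head_dim / (head_dim - 2))
    return precompute_freqs_cis(head_dim, seq_len, base=new_base)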
Multi-Query Attention (MQA)
The KV-Cache Problem
During inference, we cache keys and values for every layer to avoid recomputation. Per sequence, the cache grows as:

KV cache size ≈ 2 (K and V) × n_layers × n_heads × d_head × seq_len × bytes_per_value

For a 70B-class model with standard multi-head attention and a 32K context, that works out to roughly 85 GB per sequence, so serving even a small batch needs well over 100 GB.
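A quick back-of-the-envelope check of that figure, assuming LLaMA-2-70B-like dimensions (80 layers, 64 heads, head dim 128) and fp16 values:

# Rough KV-cache size for full multi-head attention (no MQA/GQA), fp16 values.
n_layers, n_heads, head_dim = 80, 64, 128      # LLaMA-2-70B-like dimensions (assumed)
seq_len, bytes_per_value = 32_768, 2           # 32K context, fp16

# K and V are each [seq_len, n_heads, head_dim] per layer.
kv_bytes = 2 * n_layers * n_heads * head_dim * seq_len * bytes_per_value
print(f"{kv_bytes / 1e9:.1f} GB per sequence")  # ~85.9 GB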
MQA Solution
Idea: Share a single key and value head across all query heads.
Standard attention:
- Q: [batch, seq, n_heads, d_head]
- K: [batch, seq, n_heads, d_head]
- V: [batch, seq, n_heads, d_head]
Multi-Query Attention:
- Q: [batch, seq, n_heads, d_head]
- K: [batch, seq, 1, d_head]
- V: [batch, seq, 1, d_head]
Benefits:
- KV cache reduced by factor of n_heads (e.g., 32×)
- Inference much faster
- Minimal quality loss
Grouped Query Attention (GQA)
GQA is a middle ground between MHA and MQA.
| Method | KV Heads | Query Heads | KV Cache Reduction |
|--------|----------|-------------|--------------------|
| MHA | H | H | 1× (none) |
| GQA | G | H | H/G× |
| MQA | 1 | H | H× |
Example: LLaMA 2 70B uses 8 KV heads for 64 query heads (8× reduction).
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads      # query heads per KV head
        self.head_dim = d_model // n_heads
        self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
        # Repeat KV heads to match query heads
        k = k.repeat_interleave(self.n_groups, dim=2)
        v = v.repeat_interleave(self.n_groups, dim=2)
        # Standard causal attention over [B, n_heads, T, head_dim]
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)
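A quick shape check with a small, arbitrary configuration (8 query heads sharing 2 KV heads, i.e. a 4× KV-cache reduction):

attn = GroupedQueryAttention(d_model=512, n_heads=8, n_kv_heads=2)
x = torch.randn(2, 16, 512)       # [batch, seq, d_model]
print(attn(x).shape)              # torch.Size([2, 16, 512])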
FlashAttention
The Memory Bandwidth Problem
Standard attention is memory-bound, not compute-bound. Modern GPUs have:
| GPU | Memory BW (TB/s) | FP16 TFLOPS | FLOPs : Byte |
|------|------------------|-------------|--------------|
| A100 | 2.0 | 312 | 156:1 |
| H100 | 3.35 | 990 | 296:1 |
For every byte loaded from memory, the GPU can perform roughly 150-300 floating-point operations. Yet standard attention reads and writes the full N×N attention matrix to HBM, so the compute units spend most of their time waiting on memory.
FlashAttention Algorithm
Key Insight: Never materialize the full N×N attention matrix. Instead, compute attention in blocks (tiles) that fit in fast SRAM.
Tiling Strategy:
- Divide Q, K, V into blocks
- For each block, compute partial softmax in SRAM
- Use online softmax to combine results
- Never write the N×N matrix to slow HBM (a minimal sketch of the online-softmax loop follows)
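To make this concrete, here is a minimal, unfused PyTorch sketch of tiled attention with an online softmax. It is illustrative only: the real FlashAttention kernel fuses this loop on-chip in CUDA, also tiles over queries, and handles causal masking; the function name and block size here are arbitrary.

import torch

def tiled_attention(q, k, v, block_size: int = 128):
    # q, k, v: [seq, head_dim]; keys/values are processed block by block so the
    # full seq x seq score matrix is never materialized.
    T, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((T, 1), float("-inf"), dtype=q.dtype)
    row_sum = torch.zeros(T, 1, dtype=q.dtype)
    for start in range(0, T, block_size):
        kb, vb = k[start:start + block_size], v[start:start + block_size]
        scores = (q @ kb.T) * scale                                       # [T, block]
        new_max = torch.maximum(row_max, scores.max(-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)                         # rescale old partials
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum    # equals softmax(q k^T / sqrt(d)) @ v

In the real kernel the running statistics and output block live in registers and SRAM; only the final O(N)-sized output is ever written back to HBM.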
Memory Hierarchy
┌─────────────────────────────────────────────────────────────┐
│ GPU Compute Units (CUDA Cores, Tensor Cores) │
└─────────────────────────────────────────────────────────────┘
↑↓ Very fast
┌─────────────────────────────────────────────────────────────┐
│ SRAM (On-chip) ~20 MB │
│ ~19 TB/s bandwidth │
└─────────────────────────────────────────────────────────────┘
↑↓ Slower (bottleneck!)
┌─────────────────────────────────────────────────────────────┐
│ HBM (Off-chip) 40-80 GB │
│ ~2-3 TB/s bandwidth │
└─────────────────────────────────────────────────────────────┘
FlashAttention-2 Improvements
- Better parallelism: Parallelize over sequence length dimension
- Reduced non-matmul FLOPs: Minimize warp shuffles and communication
- Work partitioning: Better distribution across thread blocks
Performance Gains
| Method | Memory | Speed |
|--------|--------|-------|
| Standard Attention | O(N²) | Baseline |
| FlashAttention | O(N) | 2-4× faster |
| FlashAttention-2 | O(N) | ~2× faster than FlashAttention |
Usage
FlashAttention is now integrated into PyTorch:
# PyTorch 2.0+
import torch
from torch.nn.functional import scaled_dot_product_attention

q = k = v = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
# Automatically dispatches to FlashAttention when available (CUDA, fp16/bf16)
output = scaled_dot_product_attention(q, k, v, is_causal=True)
Sparse Attention Patterns
For very long sequences, even O(N) memory can be prohibitive. Sparse attention patterns reduce complexity further:
Types of Sparse Attention
- Local/Sliding Window: Each token attends to fixed window
- Strided: Attend to every k-th token
- Block Sparse: Attend to specific blocks
- Longformer: Local + global attention tokens
Sliding Window Attention
Used by Mistral:
Position 0: Can see [0]
Position 1: Can see [0, 1]
Position 2: Can see [0, 1, 2]
...
Position W: Can see [0, 1, ..., W]
Position W+1: Can see [1, 2, ..., W+1] ← Slides!
With window size W, complexity is O(N × W) instead of O(N²).
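A sliding-window causal mask is easy to build and can be passed as a boolean attn_mask to scaled_dot_product_attention (True means "may attend"). A minimal sketch matching the diagram above; the helper name is made up for illustration:

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # Position i may attend to positions max(0, i - window) through i.
    i = torch.arange(seq_len).unsqueeze(1)   # [seq, 1]
    j = torch.arange(seq_len).unsqueeze(0)   # [1, seq]
    return (j <= i) & (i - j <= window)

print(sliding_window_causal_mask(seq_len=8, window=3).int())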
Summary
| Innovation | Problem Solved | Used By |
|------------|----------------|---------|
| RoPE | Position encoding that extrapolates | LLaMA, Mistral, GPT-NeoX |
| GQA | KV cache size | LLaMA 2, Mistral |
| FlashAttention | Memory bandwidth | Nearly all modern LLMs |
| Sliding Window | Very long contexts | Mistral, Longformer |
In the next post, we'll explore Part 4: FFN Modifications - activation functions (GELU, SwiGLU), gated FFNs, and Mixture of Experts.