Transformer Deep Dive: Part 2 - Architecture Changes
In Part 1, we revisited the original Transformer architecture from Vaswani et al. (2017). If you dropped that 2017 model into a modern training pipeline, it would fail spectacularly -- training would diverge, gradients would explode, and you would spend weeks debugging learning rate warmup schedules. The architecture that powers GPT-4, LLaMA 3, and Mistral has diverged from the original design in several fundamental ways, each motivated by concrete failure modes discovered through years of scaling experiments.
This post examines three critical architectural shifts: the move to decoder-only models, the repositioning of layer normalization, and the simplification of normalization itself. These are not incremental improvements. They represent hard-won lessons about what actually matters when you scale transformers to hundreds of billions of parameters.
1. The Rise of Decoder-Only Architecture
A Brief History
The original Transformer was an encoder-decoder model designed for machine translation. The encoder processes the source sentence bidirectionally, and the decoder generates the target sentence autoregressively, attending to the encoder output via cross-attention. This was a natural architecture for seq2seq tasks, but the field quickly discovered that simpler variants could be equally powerful.
In 2018, two competing paradigms emerged simultaneously. BERT (Devlin et al., 2018) took the encoder half, trained it with a masked language modeling objective, and achieved state-of-the-art results on classification and understanding tasks. GPT (Radford et al., 2018) took the decoder half, trained it with next-token prediction, and showed surprisingly strong zero-shot performance on diverse tasks.
By 2020, with GPT-3 demonstrating in-context learning at scale, the decoder-only paradigm had effectively won. Today, nearly every frontier model -- LLaMA, Mistral, Claude, GPT-4, Gemini, Qwen -- uses a decoder-only architecture.
Three Architecture Families
| Architecture | Attention Pattern | Training Objective | Notable Models |
|---|---|---|---|
| Encoder-Decoder | Bidirectional (enc) + Causal (dec) + Cross-attention | Span corruption, translation | T5, BART, mBART, Flan-T5 |
| Encoder-Only | Bidirectional | Masked Language Modeling | BERT, RoBERTa, DeBERTa |
| Decoder-Only | Causal (unidirectional) | Next-token prediction | GPT, LLaMA, Mistral, PaLM |
Why Decoder-Only Won
The dominance of decoder-only models was not preordained. Several concrete factors drove this convergence:
Unified training objective. Next-token prediction is the simplest possible objective. There are no masked spans to construct, no separate encoder and decoder losses to balance, and no architectural hyperparameters for cross-attention layers. The autoregressive objective naturally decomposes the joint probability of a sequence:

$$P(x_1, \dots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_1, \dots, x_{t-1})$$

This means the training loss is simply the negative log-likelihood:

$$\mathcal{L} = -\sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})$$
Every token in every training sequence contributes a supervision signal. There is no computation wasted on unpredicted positions: in BERT-style training, only the ~15% of tokens that are masked contribute to the loss.
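As a concrete sketch of this loss (toy tensors; in practice the logits come from the model), position t predicts token t+1, so logits and targets are shifted by one before computing the cross-entropy:

```python
import torch
import torch.nn.functional as F

# Toy setup: batch of 2 sequences, 5 tokens each, vocab of 10.
torch.manual_seed(0)
logits = torch.randn(2, 5, 10)           # (batch, seq_len, vocab)
input_ids = torch.randint(0, 10, (2, 5))

# Next-token prediction: position t predicts token t+1,
# so we shift logits left and targets right by one.
shift_logits = logits[:, :-1, :]          # predictions for positions 0..T-2
shift_targets = input_ids[:, 1:]          # ground-truth next tokens

# Negative log-likelihood averaged over every predicted token --
# every position contributes a supervision signal.
loss = F.cross_entropy(
    shift_logits.reshape(-1, 10),
    shift_targets.reshape(-1)
)
print(loss.item())
```

Note how every shifted position enters the loss, in contrast to masked language modeling where only the masked subset does.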
Scaling simplicity. With one objective and one architecture, the only decisions are model size, data, and compute. This aligns perfectly with the scaling laws discovered by Kaplan et al. (2020) and later refined by Hoffmann et al. (2022, "Chinchilla"). The research community converged on a simple recipe: take a decoder-only transformer, scale it up, and feed it more data.
Emergent in-context learning. Perhaps the most surprising property of large decoder-only models is in-context learning (ICL): the ability to perform new tasks by conditioning on a few examples in the prompt, without any gradient updates. This effectively turns a single model into a general-purpose task solver, eliminating the need for task-specific architectures.
KV-cache efficiency. During autoregressive generation, decoder-only models naturally support a key-value cache. At each timestep, we only need to compute the query for the new token and attend to the cached keys and values from all previous tokens. This makes generation $O(n)$ per token rather than the $O(n^2)$ cost of recomputing attention over the full sequence at every step. Encoder-decoder models require maintaining both an encoder KV-cache and a decoder KV-cache, with cross-attention adding complexity.
Task unification via prompting. A decoder-only model can handle classification, summarization, translation, reasoning, and code generation -- all through different prompt formats. The model's input and output share the same vocabulary and representation space, eliminating the need for task-specific heads.
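The KV-cache mechanism described above can be sketched in a few lines (single attention head, projections omitted; the function name and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def generate_step(q_new, k_new, v_new, k_cache, v_cache):
    """One decoding step with a KV-cache (single head, toy sketch).

    q_new, k_new, v_new: (1, head_dim) -- projections of the new token.
    k_cache, v_cache:    (t, head_dim) -- keys/values of previous tokens.
    """
    # Append the new token's key/value to the cache.
    k_cache = torch.cat([k_cache, k_new], dim=0)   # (t+1, head_dim)
    v_cache = torch.cat([v_cache, v_new], dim=0)

    # The new query attends to all cached positions: O(t) work,
    # instead of recomputing the full (t+1) x (t+1) attention matrix.
    scale = q_new.shape[-1] ** -0.5
    attn = F.softmax((q_new @ k_cache.T) * scale, dim=-1)  # (1, t+1)
    out = attn @ v_cache                                   # (1, head_dim)
    return out, k_cache, v_cache
```

The cache grows by one row per step, and causality comes for free: future tokens simply are not in the cache yet, so no explicit mask is needed during generation.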
Causal Masking in Detail
The defining feature of a decoder-only model is the causal attention mask, which ensures that each token can only attend to itself and preceding tokens. This is implemented by adding a mask to the attention scores before softmax:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

where $M$ is an upper-triangular matrix of $-\infty$ values:

$$M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$

After the softmax, positions with $-\infty$ scores become zero, effectively preventing information from flowing backward in the sequence. Here is a complete implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F

def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Create a causal attention mask.

    Returns a (seq_len, seq_len) tensor where future positions
    are set to -inf and past/current positions are 0.
    """
    mask = torch.triu(
        torch.ones(seq_len, seq_len, device=device),
        diagonal=1
    )
    return mask.masked_fill(mask == 1, float('-inf'))

def causal_self_attention(
    x: torch.Tensor,
    W_q: nn.Linear,
    W_k: nn.Linear,
    W_v: nn.Linear,
    n_heads: int
) -> torch.Tensor:
    """Causal multi-head self-attention.

    Args:
        x: Input tensor of shape (batch, seq_len, d_model)
        W_q, W_k, W_v: Projection layers
        n_heads: Number of attention heads
    """
    B, T, C = x.shape
    head_dim = C // n_heads

    # Project to Q, K, V
    q = W_q(x).view(B, T, n_heads, head_dim).transpose(1, 2)
    k = W_k(x).view(B, T, n_heads, head_dim).transpose(1, 2)
    v = W_v(x).view(B, T, n_heads, head_dim).transpose(1, 2)

    # Scaled dot-product attention with causal mask
    scale = head_dim ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale

    # Apply causal mask
    mask = create_causal_mask(T, x.device)
    attn = attn + mask  # Broadcasting: (B, H, T, T) + (T, T)
    attn = F.softmax(attn, dim=-1)
    out = attn @ v  # (B, H, T, head_dim)

    # Recombine heads
    out = out.transpose(1, 2).contiguous().view(B, T, C)
    return out
The resulting attention pattern looks like this (1 = attends, 0 = masked):
| Query \ Key | Pos 0 | Pos 1 | Pos 2 | Pos 3 |
|---|---|---|---|---|
| Pos 0 | 1 | 0 | 0 | 0 |
| Pos 1 | 1 | 1 | 0 | 0 |
| Pos 2 | 1 | 1 | 1 | 0 |
| Pos 3 | 1 | 1 | 1 | 1 |
Prefix LM: A Hybrid Approach
It is worth noting that some models use a prefix LM pattern, where a prefix of the sequence uses bidirectional attention (no causal mask) and the remainder uses causal attention. This can be seen as a decoder-only model where the prompt portion gets bidirectional context. U-PaLM and some T5 variants explored this approach, though pure causal masking remains dominant.
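A prefix-LM mask can be built from the causal mask by clearing the prefix block (a sketch; `create_prefix_lm_mask` and `prefix_len` are illustrative names, not from any particular codebase):

```python
import torch

def create_prefix_lm_mask(seq_len: int, prefix_len: int) -> torch.Tensor:
    """Additive attention mask for a prefix LM.

    Tokens inside the prefix attend bidirectionally to the whole
    prefix; tokens after it attend causally, seeing only positions
    at or before their own.
    """
    # Start from a standard causal mask: -inf strictly above the diagonal.
    mask = torch.triu(
        torch.full((seq_len, seq_len), float('-inf')), diagonal=1
    )
    # Unmask the prefix block so prefix tokens see each other fully.
    mask[:prefix_len, :prefix_len] = 0.0
    return mask
```

The prefix still cannot see the suffix (the causal mask handles that), while suffix tokens see the entire prefix plus their own causal history.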
2. Pre-Layer Normalization
The Training Instability Problem
The original Transformer placed layer normalization after the residual connection -- a design now called Post-Layer Normalization (Post-LN). This worked for the relatively small models of 2017, but as researchers tried to scale up, they hit a wall: training became increasingly unstable, requiring carefully tuned learning rate warmup schedules, and often diverging entirely for deeper models.
Xiong et al. (2020) provided a theoretical explanation. In Post-LN, the expected gradient norm near the output layers is large at initialization, while the gradients at earlier layers can vanish. This creates a precarious optimization landscape that demands very careful warmup to avoid divergence.
Post-LN vs Pre-LN: Structural Comparison
The difference is a simple reordering of operations, but the consequences are profound.
Post-LN (Original Transformer):

$$x_{l+1} = \mathrm{LayerNorm}\big(x_l + \mathrm{Sublayer}(x_l)\big)$$

Pre-LN (Modern):

$$x_{l+1} = x_l + \mathrm{Sublayer}\big(\mathrm{LayerNorm}(x_l)\big)$$
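In code, the two orderings differ by a single line (a minimal sketch with a stand-in sublayer, not a full block implementation):

```python
import torch
import torch.nn as nn

def post_ln_sublayer(x, sublayer, norm):
    # Post-LN: normalize AFTER the residual add -- the residual
    # stream itself passes through LayerNorm at every block.
    return norm(x + sublayer(x))

def pre_ln_sublayer(x, sublayer, norm):
    # Pre-LN: normalize only the sublayer input -- the residual
    # stream is an untouched additive path.
    return x + sublayer(norm(x))

# Same ingredients, different wiring:
x = torch.randn(2, 4, 8)
block = nn.Linear(8, 8)      # stand-in for attention or FFN
norm = nn.LayerNorm(8)
y_post = post_ln_sublayer(x, block, norm)
y_pre = pre_ln_sublayer(x, block, norm)
```

In the Post-LN version the output is always freshly normalized; in the Pre-LN version the input `x` survives into the output unmodified, which is exactly the gradient highway discussed next.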
Why This Matters: Gradient Flow Analysis
The critical insight is about the residual stream. In Pre-LN, if we unroll the residual connections across layers, the output of the network can be written as:

$$x_L = x_0 + \sum_{l=0}^{L-1} F_l\big(\mathrm{LN}(x_l)\big)$$

where $F_l$ represents the sublayer function (attention or FFN) at layer $l$. This has a direct additive path from input to output -- the residual stream is never passed through a normalization layer.
Taking the gradient of the loss with respect to the activations $x_l$ at an early layer:

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L}\left(I + \sum_{m=l}^{L-1} \frac{\partial F_m\big(\mathrm{LN}(x_m)\big)}{\partial x_l}\right)$$

Because this Jacobian contains the identity term $I$, the gradient flows directly from the loss back to any layer without being multiplicatively attenuated by intervening normalization layers. This is analogous to how ResNets solved the vanishing gradient problem in CNNs.
In Post-LN, by contrast, the normalization sits directly on the residual path. Each LayerNorm introduces a Jacobian that can attenuate or amplify gradients, and these effects compound across layers. Xiong et al. (2020) showed that, at initialization, the gradient norm of the final layers in a Post-LN transformer is of order

$$\mathcal{O}\big(d\sqrt{\ln d}\big)$$

independent of the depth $L$, whereas for Pre-LN it scales as $\mathcal{O}\big(d\sqrt{\ln d / L}\big)$, shrinking as depth grows.

This explains why Post-LN requires extensive learning rate warmup (often thousands of steps) while Pre-LN can begin training with the full learning rate immediately.
Practical Comparison
| Property | Post-LN | Pre-LN |
|---|---|---|
| Learning rate warmup | Essential (thousands of steps) | Minimal or unnecessary |
| Maximum stable learning rate | Smaller | Larger |
| Training stability for deep models | Fragile, prone to divergence | Robust |
| Final performance (when training succeeds) | Marginally better in some cases | Comparable |
| Used by | Original Transformer, early BERT | GPT-2/3, LLaMA, Mistral, PaLM |
The marginal performance advantage of Post-LN has motivated some recent work on stabilizing it (e.g., Admin initialization), but in practice, the stability advantages of Pre-LN have made it the universal default for large-scale training.
Pre-LN Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreLNTransformerBlock(nn.Module):
    """A single transformer block with Pre-Layer Normalization.

    This is the standard building block used in GPT-2, GPT-3,
    LLaMA, and most modern decoder-only LLMs.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attention = nn.MultiheadAttention(
            d_model, n_heads,
            dropout=dropout,
            bias=bias,
            batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            nn.GELU(),
            nn.Linear(d_ff, d_model, bias=bias),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # ---- Attention sub-block (Pre-LN) ----
        # Normalize BEFORE the sublayer
        normed = self.ln1(x)
        attn_out, _ = self.attention(
            normed, normed, normed,
            attn_mask=attn_mask,
            need_weights=False
        )
        # Residual connection bypasses normalization
        x = x + self.dropout(attn_out)

        # ---- FFN sub-block (Pre-LN) ----
        x = x + self.ffn(self.ln2(x))
        return x

class PreLNTransformer(nn.Module):
    """Complete Pre-LN decoder-only transformer.

    Note the final LayerNorm after the last block -- this is
    essential because the residual stream is unnormalized.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        d_ff: int = 3072,
        max_seq_len: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            PreLNTransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        # Final LayerNorm -- critical for Pre-LN architecture
        self.ln_final = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        B, T = input_ids.shape
        positions = torch.arange(T, device=input_ids.device)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.dropout(x)

        # Causal mask
        mask = create_causal_mask(T, input_ids.device)
        for block in self.blocks:
            x = block(x, attn_mask=mask)

        # Final normalization before output projection
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits
The Final LayerNorm
Notice the ln_final in the model above. In a Pre-LN architecture, the residual stream accumulates contributions from every layer without ever being normalized on the main path. By the time we reach the last layer, the activations can have grown substantially. The final LayerNorm brings the representation back to a normalized scale before projecting to vocabulary logits:

$$\text{logits} = \mathrm{LayerNorm}(x_L)\, W_{\text{vocab}}$$
Omitting this final normalization typically leads to training instability or poor performance. Every major Pre-LN model (GPT-2, LLaMA, Mistral, PaLM) includes it.
3. RMSNorm: Simpler and Faster Normalization
LayerNorm Revisited
Standard Layer Normalization (Ba et al., 2016) computes both the mean and variance of the activations, then re-centers and re-scales:

$$\mathrm{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where:

$$\mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \qquad \sigma^2 = \frac{1}{d}\sum_{i=1}^{d}(x_i - \mu)^2$$

Here $\gamma$ and $\beta$ are learnable gain and bias parameters, each of dimension $d$. The computation involves two passes over the data (one for the mean, one for the variance), plus the subtraction and division.
The RMSNorm Simplification
Zhang and Sennrich (2019) proposed Root Mean Square Layer Normalization, which eliminates the mean computation entirely:

$$\mathrm{RMSNorm}(x) = \gamma \odot \frac{x}{\mathrm{RMS}(x)}$$

where:

$$\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2}$$

(in practice a small $\epsilon$ is added inside the square root for numerical stability). Two things are different here. First, there is no mean subtraction -- the input is divided by its root mean square directly. Second, there is no bias parameter $\beta$, only a gain $\gamma$.
Why Removing Mean Subtraction Works
The key insight is a mathematical relationship between the RMS and the standard deviation. The variance can be decomposed as:

$$\sigma^2 = \frac{1}{d}\sum_{i=1}^{d} x_i^2 - \mu^2 = \mathrm{RMS}(x)^2 - \mu^2$$

Therefore:

$$\mathrm{RMS}(x) = \sqrt{\sigma^2 + \mu^2}$$

For neural network activations, especially in deeper layers with residual connections, the mean tends to be close to zero. When $\mu \approx 0$, we get $\mathrm{RMS}(x) \approx \sigma$, and RMSNorm becomes approximately equivalent to LayerNorm (without the centering).
Zhang and Sennrich (2019) provided empirical evidence that the re-centering operation (mean subtraction) contributes negligibly to the success of LayerNorm, while the re-scaling operation (division by a measure of spread) is what actually stabilizes training. This is an elegant instance of removing unnecessary computation without sacrificing model quality.
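This approximate equivalence is easy to check numerically (a toy sketch; affine parameters omitted from both norms): for zero-mean inputs, the two normalizations coincide exactly.

```python
import torch

def rms_norm(x, eps=1e-6):
    # Divide by the root mean square; no mean subtraction, no bias.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def layer_norm(x, eps=1e-6):
    # Re-center AND re-scale (affine parameters omitted for clarity).
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mu) * torch.rsqrt(var + eps)

torch.manual_seed(0)
x = torch.randn(4, 512)
x_centered = x - x.mean(dim=-1, keepdim=True)   # force mu = 0

# With zero mean, RMS(x) = sigma, so the two norms agree.
print(torch.allclose(rms_norm(x_centered), layer_norm(x_centered), atol=1e-4))
# prints: True
```

For inputs with a large mean the two diverge, which is exactly the re-centering that Zhang and Sennrich found to be dispensable.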
Computational Savings
| Operation | LayerNorm | RMSNorm |
|---|---|---|
| Reduction passes | 2 (mean, then variance) | 1 (sum of squares) |
| Mean computation | Required | Not needed |
| Mean subtraction | Required | Not needed |
| Learnable parameters | and (2d) | only (d) |
| Wall-clock speedup | Baseline | ~10-15% faster |
| Memory for parameters | 2d floats | d floats |
The 10-15% speedup may seem modest, but normalization is applied at every sublayer of every transformer block. In a 32-layer LLaMA model, that is 64 normalization operations per forward pass. At scale, this adds up to meaningful savings in both training and inference time.
Implementation
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Used by LLaMA, LLaMA 2, LLaMA 3, Mistral, Gemma,
    and most modern LLMs as a drop-in replacement for LayerNorm.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # RMS = sqrt(mean(x^2) + eps)
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast to float32 for numerical stability, then back
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
A few implementation details are worth noting:
- `torch.rsqrt` computes $1/\sqrt{x}$ in a single fused operation, which is faster than computing `sqrt` and then dividing.
- Float32 accumulation: The norm computation is done in float32 even if the input is in bfloat16. This prevents numerical issues when squaring small values. The result is cast back to the original dtype afterward.
- No bias parameter: There is no additive bias, which means the output is purely a scaled version of the input direction.
A LLaMA-Style Block with RMSNorm
Combining Pre-LN with RMSNorm gives us the standard building block of modern LLMs:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LLaMABlock(nn.Module):
    """Transformer block following the LLaMA architecture.

    Uses Pre-LN with RMSNorm instead of LayerNorm.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)
        self.attention = nn.MultiheadAttention(
            d_model, n_heads,
            bias=False,
            batch_first=True
        )
        # SwiGLU FFN (covered in Part 4)
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        # Pre-RMSNorm + Attention
        h = self.attn_norm(x)
        attn_out, _ = self.attention(h, h, h, attn_mask=mask)
        x = x + attn_out

        # Pre-RMSNorm + SwiGLU FFN
        h = self.ffn_norm(x)
        x = x + self.w2(F.silu(self.w1(h)) * self.w3(h))
        return x
Adoption Across Models
| Model | Year | Normalization | Position |
|---|---|---|---|
| Original Transformer | 2017 | LayerNorm | Post-LN |
| GPT-2 | 2019 | LayerNorm | Pre-LN |
| GPT-3 | 2020 | LayerNorm | Pre-LN |
| PaLM | 2022 | LayerNorm | Pre-LN (parallel) |
| LLaMA | 2023 | RMSNorm | Pre-LN |
| LLaMA 2 | 2023 | RMSNorm | Pre-LN |
| Mistral 7B | 2023 | RMSNorm | Pre-LN |
| Gemma | 2024 | RMSNorm | Pre-LN |
| LLaMA 3 | 2024 | RMSNorm | Pre-LN |
The trend is clear: RMSNorm with Pre-LN positioning has become the de facto standard.
A Note on QK-Norm
An emerging technique is QK-Norm, where an additional RMSNorm is applied to the query and key vectors before computing attention scores. This prevents attention logits from growing too large, which can cause issues with float16/bfloat16 precision:

$$\text{scores} = \frac{\mathrm{RMSNorm}(q)\,\mathrm{RMSNorm}(k)^\top}{\sqrt{d_k}}$$
QK-Norm was popularized by ViT-22B (Dehghani et al., 2023) and has since been adopted by models such as OLMo 2 and Gemma 3 for additional training stability, especially at very large scales.
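A sketch of where QK-Norm slots into the attention computation (the `rms` helper and `qk_norm_attention` names are illustrative, not from any particular codebase):

```python
import torch
import torch.nn.functional as F

def rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMS normalization over the head dimension (learned gain
    # omitted for brevity; real models use an RMSNorm module here).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def qk_norm_attention(q, k, v):
    """Attention with RMSNorm applied to queries and keys (QK-Norm).

    q, k, v: (batch, heads, seq, head_dim). Normalizing q and k bounds
    the attention logits, which helps in fp16/bf16 training.
    """
    q, k = rms(q), rms(k)
    scale = q.shape[-1] ** -0.5
    attn = F.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v
```

A useful property follows directly: because RMS normalization is scale-invariant, the attention logits stay bounded no matter how large the raw query and key activations grow. (A causal mask would be added to the scores exactly as in the earlier implementation.)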
4. Removing Bias Terms
One additional change worth mentioning: most modern LLMs remove bias terms from linear layers throughout the model. The original Transformer used biases in the attention projections ($W_Q$, $W_K$, $W_V$, $W_O$), the FFN layers, and layer normalization.
Modern models like LLaMA set bias=False everywhere. The rationale is:
- Parameter efficiency: Bias terms add only $d$ parameters per linear layer against the $d^2$ of the weight matrix, which is negligible compared to weight matrices but adds implementation complexity.
- RMSNorm has no bias: Since RMSNorm already omits the bias term $\beta$, removing biases from linear layers is consistent.
- Empirical finding: Multiple ablation studies have shown no degradation from removing biases.
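The parameter accounting is easy to verify with a toy example:

```python
import torch.nn as nn

d_model = 4096
with_bias = nn.Linear(d_model, d_model, bias=True)
no_bias = nn.Linear(d_model, d_model, bias=False)

n_with = sum(p.numel() for p in with_bias.parameters())
n_without = sum(p.numel() for p in no_bias.parameters())

# The bias contributes only d_model parameters against d_model^2 weights:
print(n_with - n_without)   # 4096
print(n_without)            # 16777216 (= 4096 * 4096)
```

The bias is about 0.02% of the layer's parameters here, so dropping it costs essentially nothing in capacity while simplifying the implementation.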
Summary: Original Transformer vs Modern LLMs
| Component | Original Transformer (2017) | Modern LLMs (2023+) |
|---|---|---|
| Architecture | Encoder-Decoder | Decoder-Only |
| Training objective | Seq2Seq cross-entropy | Next-token prediction |
| LayerNorm position | Post-LN | Pre-LN |
| Normalization | LayerNorm (with bias) | RMSNorm (no bias) |
| Bias terms | Yes (everywhere) | No (removed) |
| Positional encoding | Sinusoidal (additive) | RoPE (multiplicative) |
| FFN activation | ReLU | SwiGLU |
| Attention variant | Multi-Head (MHA) | Grouped-Query (GQA) |
Each of these changes is motivated by a specific failure mode or efficiency improvement discovered through scaling. The modern LLM is not a minor evolution of the original Transformer -- it is a substantially rearchitected system, rebuilt piece by piece as researchers discovered what breaks at scale.
In the next post, we will dive into Part 3: Attention Modifications -- the evolution of positional encoding from sinusoidal to RoPE, the KV-cache efficiency improvements of Multi-Query and Grouped-Query Attention, and how FlashAttention exploits GPU memory hierarchies to make attention both faster and more memory-efficient.
References
- Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
- Radford, A., Narasimhan, K., Salimans, T., Sutskever, I. (2018). "Improving Language Understanding by Generative Pre-Training." OpenAI.
- Devlin, J., Chang, M.-W., Lee, K., Toutanova, K. (2018). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." NAACL 2019.
- Xiong, R., Yang, Y., He, D., et al. (2020). "On Layer Normalization in the Transformer Architecture." ICML 2020.
- Zhang, B. and Sennrich, R. (2019). "Root Mean Square Layer Normalization." NeurIPS 2019.
- Ba, J. L., Kiros, J. R., Hinton, G. E. (2016). "Layer Normalization." arXiv:1607.06450.
- Kaplan, J., McCandlish, S., Henighan, T., et al. (2020). "Scaling Laws for Neural Language Models." arXiv:2001.08361.
- Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). "Training Compute-Optimal Large Language Models." NeurIPS 2022.
- Touvron, H., Lavril, T., Izacard, G., et al. (2023). "LLaMA: Open and Efficient Foundation Language Models." arXiv:2302.13971.
- Touvron, H., Martin, L., Stone, K., et al. (2023). "Llama 2: Open Foundation and Fine-Tuned Chat Models." arXiv:2307.09288.
Written by Suchinthaka Wanninayaka
AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.