Transformer Deep Dive: Part 4 - FFN Modifications
Evolution of the Feed-Forward Network - from ReLU to GELU, SwiGLU gated activations, and Mixture of Experts for scaling to trillion-parameter models.
Suchinthaka W.
January 18, 2025 · 6 min read
While attention determines which tokens to focus on, the Feed-Forward Network (FFN) does the actual processing of that information. The FFN has undergone significant evolution since the original Transformer.
The Role of the FFN
In the original Transformer, the FFN is defined as:

$$\text{FFN}(x) = \max(0,\; xW_1 + b_1)\,W_2 + b_2$$

where $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$, $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$, and typically $d_{\text{ff}} = 4\,d_{\text{model}}$.
The FFN serves several critical purposes:
- Non-linearity: Introduces non-linear transformations
- Feature transformation: Projects to higher-dimensional space for richer interactions
- Knowledge storage: Recent research suggests FFNs act as key-value memories storing factual knowledge
- Computational capacity: Contains the majority of parameters
Parameter Distribution
| Component | Parameters | Typical Ratio |
|-----------|------------|---------------|
| Attention ($W_Q$, $W_K$, $W_V$, $W_O$) | $4\,d_{\text{model}}^2$ | ~33% |
| FFN ($W_1$, $W_2$) | $8\,d_{\text{model}}^2$ (with $d_{\text{ff}} = 4\,d_{\text{model}}$) | ~67% |

The FFN typically contains about twice as many parameters as attention!
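To see where the ~33% / ~67% split comes from, here is a quick back-of-the-envelope count, using an illustrative $d_{\text{model}} = 768$ and the standard $d_{\text{ff}} = 4\,d_{\text{model}}$ (biases ignored):

```python
# Illustrative parameter count (d_model = 768 is an arbitrary example)
d_model = 768
d_ff = 4 * d_model

attn_params = 4 * d_model * d_model           # W_Q, W_K, W_V, W_O
ffn_params = d_model * d_ff + d_ff * d_model  # W_1, W_2

total = attn_params + ffn_params
print(f"Attention: {attn_params / total:.0%}, FFN: {ffn_params / total:.0%}")
# Attention: 33%, FFN: 67%
```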
Evolution Timeline
2017: ReLU FFN (Original Transformer)
2016/18: GELU Activation (BERT, GPT)
2020: SwiGLU FFN (GLU variants)
2022+: MoE Integration (Mixtral, GPT-4)
2024: Modern LLMs (SwiGLU standard)
Activation Functions
ReLU (Original)

$$\text{ReLU}(x) = \max(0, x)$$
Properties:
- Simple, fast to compute
- Sparse activations (many zeros)
- "Dying ReLU" problem: neurons that output 0 for all inputs
GELU (Gaussian Error Linear Unit)
$$\text{GELU}(x) = x \cdot \Phi(x)$$

where $\Phi(x)$ is the standard Gaussian CDF.
Properties:
- Smooth, differentiable everywhere
- Probabilistic interpretation: weights inputs by their quantile
- Used in BERT, GPT-2/3, RoBERTa
```python
import math
import torch

def gelu(x):
    # Exact GELU: x * Phi(x), with Phi computed via the error function
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Fast approximation: GELU(x) ≈ x * sigmoid(1.702 * x)
def gelu_fast(x):
    return x * torch.sigmoid(1.702 * x)
```
SiLU/Swish

$$\text{SiLU}(x) = \text{Swish}(x) = x \cdot \sigma(x)$$
Properties:
- Self-gated: the input gates itself
- Smooth approximation of ReLU
- Unbounded above, bounded below
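SiLU is built into PyTorch as `F.silu`; a quick sketch confirming the $x \cdot \sigma(x)$ definition:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4)
manual = x * torch.sigmoid(x)  # x · σ(x)
assert torch.allclose(manual, F.silu(x))
```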
Gated Linear Units (GLU)
The key innovation in modern FFNs is gating: using one branch to control information flow through another.
Original GLU

$$\text{GLU}(x) = (xW + b) \otimes \sigma(xV + c)$$

The sigmoid-activated branch "gates" the linear branch.
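A minimal PyTorch sketch of this gating (names are illustrative):

```python
import torch
import torch.nn as nn

class GLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.linear = nn.Linear(d_model, d_ff)  # value branch
        self.gate = nn.Linear(d_model, d_ff)    # gating branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * torch.sigmoid(self.gate(x))
```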
SwiGLU
Shazeer (2020) proposed SwiGLU, combining Swish activation with gating:

$$\text{SwiGLU}(x) = \text{Swish}(xW_1) \otimes (xW_3)$$

or equivalently:

$$\text{SwiGLU}(x) = \left(xW_1 \cdot \sigma(xW_1)\right) \otimes (xW_3)$$

Final FFN with SwiGLU:

$$\text{FFN}_{\text{SwiGLU}}(x) = \left(\text{Swish}(xW_1) \otimes xW_3\right) W_2$$
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # Note: 3 weight matrices instead of 2
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # up projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: Swish(x @ W1) * (x @ W3), projected back down through W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```
Parameter Adjustment
SwiGLU has 3 weight matrices instead of 2. To maintain the same parameter count, the hidden dimension is reduced:

$$d_{\text{ff}} = \frac{2}{3} \cdot 4\,d_{\text{model}} = \frac{8}{3}\,d_{\text{model}}$$

Common configurations:
- LLaMA: $d_{\text{ff}} = \frac{8}{3}\,d_{\text{model}}$, rounded up to a multiple of 256
- This gives roughly the same total parameters ($3 \cdot d_{\text{model}} \cdot d_{\text{ff}} \approx 8\,d_{\text{model}}^2$) as a standard FFN with $d_{\text{ff}} = 4\,d_{\text{model}}$
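For example, plugging LLaMA-7B's $d_{\text{model}} = 4096$ into this rule reproduces its published hidden dimension:

```python
d_model = 4096                      # LLaMA-7B hidden size
d_ff = int(8 * d_model / 3)         # 10922
d_ff = 256 * ((d_ff + 255) // 256)  # round up to a multiple of 256
print(d_ff)                         # 11008, the value used in LLaMA-7B
```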
GLU Variant Comparison
| Variant | Formula | Used By |
|---------|---------|---------|
| GLU | $\sigma(xW) \otimes (xV)$ | Original |
| ReGLU | $\text{ReLU}(xW) \otimes (xV)$ | - |
| GEGLU | $\text{GELU}(xW) \otimes (xV)$ | T5 variants |
| SwiGLU | $\text{Swish}(xW) \otimes (xV)$ | LLaMA, Mistral, PaLM |
Mixture of Experts (MoE)
The Scaling Challenge
Dense models face a fundamental limitation: all parameters are used for every token. Scaling to 1T+ parameters makes inference expensive.
MoE Solution
Idea: Have multiple "expert" FFNs but only activate a subset for each token.
$$y = \sum_{i=1}^{N} G(x)_i \cdot E_i(x)$$

where:
- $E_i$ are expert networks (typically FFNs)
- $G(x)$ is a gating/routing function that selects experts
Sparse Gating
The gating function typically uses top-k routing:

$$G(x) = \text{Softmax}(\text{TopK}(x \cdot W_g,\; k))$$

where $\text{TopK}$ keeps the $k$ largest router logits and masks the rest to $-\infty$.
Only the top-k experts (usually k=1 or k=2) are activated per token.
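In code, the routing step on its own looks roughly like this (shapes and dimensions here are illustrative; the full layer is shown below):

```python
import torch
import torch.nn.functional as F

d_model, n_experts, k = 512, 8, 2
x = torch.randn(4, 16, d_model)        # [batch, seq, d_model]
w_g = torch.randn(d_model, n_experts)  # router weights

logits = x @ w_g                                     # [batch, seq, n_experts]
topk_vals, topk_idx = torch.topk(logits, k, dim=-1)  # best k experts per token
gates = F.softmax(topk_vals, dim=-1)                 # renormalize over selected experts
```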
Architecture
Input x
│
▼
┌─────────┐
│ Router │ ─────► Select Top-K experts
└─────────┘
│
├──────────────┬──────────────┬──────────────┐
▼ ▼ ▼ ▼
┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐
│Expert 1│ │Expert 2│ │Expert 3│ │Expert N│
│ (FFN) │ │ (FFN) │ │ (FFN) │ │ (FFN) │
└────────┘ └────────┘ └────────┘ └────────┘
│ │ │ │
└──────────────┴──────────────┴──────────────┘
│
▼ Weighted sum
Output
Load Balancing
A major challenge is ensuring experts are used evenly. Auxiliary losses encourage balance:

$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

where:
- $f_i$ = fraction of tokens routed to expert $i$
- $P_i$ = average router probability for expert $i$
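A minimal sketch of this auxiliary loss in the style of the Switch Transformer formulation (the coefficient `alpha` and the tensor shapes are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, expert_indices, n_experts, alpha=0.01):
    # router_logits: [tokens, n_experts]; expert_indices: [tokens, top_k]
    probs = F.softmax(router_logits, dim=-1)

    # f_i: fraction of tokens whose top-k choice includes expert i
    one_hot = F.one_hot(expert_indices, n_experts).float()  # [tokens, top_k, n_experts]
    f = one_hot.sum(dim=1).mean(dim=0)                      # [n_experts]

    # P_i: average router probability assigned to expert i
    P = probs.mean(dim=0)                                   # [n_experts]

    return alpha * n_experts * torch.sum(f * P)
```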
MoE Models
| Model | Total Params | Active Params | Experts | Top-K |
|-------|--------------|---------------|---------|-------|
| Mixtral 8x7B | 46.7B | 12.9B | 8 | 2 |
| Mixtral 8x22B | 141B | 39B | 8 | 2 |
| GPT-4 (rumored) | ~1.8T | ~220B | 16 | 2 |
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        # Router
        self.gate = nn.Linear(d_model, n_experts, bias=False)
        # Experts (each is an FFN)
        self.experts = nn.ModuleList([
            SwiGLU(d_model, d_ff) for _ in range(n_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Router logits
        router_logits = self.gate(x)  # [B, T, n_experts]

        # Top-k routing: keep the k best experts per token, renormalize their weights
        weights, indices = torch.topk(router_logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)  # [B, T, top_k]

        # Compute expert outputs (simplified; real implementations batch tokens per expert)
        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # Tokens routed to this expert through any of their top-k slots
            selected = (indices == i)    # [B, T, top_k]
            mask = selected.any(dim=-1)  # [B, T]
            if mask.any():
                expert_out = expert(x[mask])  # [n_selected, d_model]
                # Router weight assigned to this expert at each selected token
                expert_weight = (weights * selected).sum(dim=-1)[mask].unsqueeze(-1)
                output[mask] += expert_weight * expert_out
        return output
```
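A quick smoke test of the layer above (dimensions chosen arbitrarily):

```python
moe = MoELayer(d_model=512, d_ff=1408, n_experts=8, top_k=2)
x = torch.randn(2, 16, 512)
print(moe(x).shape)  # torch.Size([2, 16, 512])
```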
Expert Parallelism
MoE enables a new dimension of parallelism:
| Parallelism | What's Split |
|-------------|--------------|
| Data | Batch across devices |
| Tensor | Layer weights across devices |
| Pipeline | Layers across devices |
| Expert | Different experts on different devices |
Modern FFN Architecture
Putting it all together, a modern LLM FFN block:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModernFFN(nn.Module):
    """LLaMA-style FFN with SwiGLU."""
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        # Default: 8/3 * d_model, rounded up to a multiple of 256
        if d_ff is None:
            d_ff = int(8 * d_model / 3)
            d_ff = 256 * ((d_ff + 255) // 256)
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: down_proj( SiLU(gate_proj(x)) * up_proj(x) )
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```
Summary
| Era | FFN Type | Activation | Parameters |
|-----|----------|------------|------------|
| 2017 | Standard | ReLU | 2 matrices |
| 2018 | Standard | GELU | 2 matrices |
| 2020+ | Gated (SwiGLU) | Swish | 3 matrices |
| 2023+ | MoE + SwiGLU | Swish | N×3 matrices |
In the next post, we'll explore Part 5: Training Improvements - optimizers (Adam, AdamW, Lion), learning rate schedules, mixed precision training, and gradient checkpointing.