Tags: transformers · ffn · swiglu · moe · activation-functions · deep-learning

Transformer Deep Dive: Part 4 - FFN Modifications

Evolution of the Feed-Forward Network - from ReLU to GELU, SwiGLU gated activations, and Mixture of Experts for scaling to trillion-parameter models.


Suchinthaka W.

January 18, 2025 · 6 min read

While attention determines which tokens to focus on, the Feed-Forward Network (FFN) determines how that information is transformed. The FFN has evolved significantly since the original Transformer.

The Role of the FFN

In the original Transformer, the FFN is defined as:

$$\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$
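
As a point of reference, here is a minimal PyTorch sketch of this original two-layer FFN, using the base-model sizes from the original paper ($d_{model} = 512$, $d_{ff} = 2048$):

import torch
import torch.nn as nn

class OriginalFFN(nn.Module):
    """Position-wise FFN from the original Transformer: expand, ReLU, contract."""
    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)   # expand: d_model -> d_ff
        self.w2 = nn.Linear(d_ff, d_model)   # contract: d_ff -> d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))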

The FFN serves several critical purposes:

  1. Non-linearity: Introduces non-linear transformations
  2. Feature transformation: Projects to higher-dimensional space for richer interactions
  3. Knowledge storage: Recent research suggests FFNs act as key-value memories storing factual knowledge
  4. Computational capacity: Contains the majority of parameters

Parameter Distribution

| Component | Parameters | Typical Ratio |
|-----------|------------|---------------|
| Attention (Q, K, V, O) | $4d^2$ | ~33% |
| FFN ($W_1$, $W_2$) | $2 \times d \times d_{ff} = 8d^2$ (with $d_{ff} = 4d$) | ~67% |

With $d_{ff} = 4d$, the FFN contains about twice as many parameters as attention, as the quick check below confirms.
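
A quick sanity check of these ratios, using an assumed hidden size purely for illustration:

d_model = 4096                    # illustrative hidden size
d_ff = 4 * d_model                # standard FFN expansion

attn_params = 4 * d_model ** 2    # Q, K, V, O projections
ffn_params = 2 * d_model * d_ff   # W1 and W2

print(ffn_params / attn_params)                   # 2.0
print(ffn_params / (attn_params + ffn_params))    # ~0.67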

Evolution Timeline

2017: ReLU FFN (Original Transformer)
2016/18: GELU Activation (BERT, GPT)
2020: SwiGLU FFN (GLU variants)
2022+: MoE Integration (Mixtral, GPT-4)
2024: Modern LLMs (SwiGLU standard)

Activation Functions

ReLU (Original)

$$\text{ReLU}(x) = \max(0, x)$$

Properties:

  • Simple, fast to compute
  • Sparse activations (many zeros)
  • "Dying ReLU" problem: neurons that output 0 for all inputs

GELU (Gaussian Error Linear Unit)

$$\text{GELU}(x) = x \cdot \Phi(x) \approx x \cdot \sigma(1.702x)$$

where $\Phi$ is the standard Gaussian CDF.

Properties:

  • Smooth, differentiable everywhere
  • Probabilistic interpretation: weights inputs by their quantile
  • Used in BERT, GPT-2/3, RoBERTa

import math
import torch

def gelu(x):
    # Exact GELU via the Gaussian CDF (erf form)
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Fast approximation: x * sigmoid(1.702 * x)
def gelu_fast(x):
    return x * torch.sigmoid(1.702 * x)

SiLU/Swish

$$\text{SiLU}(x) = x \cdot \sigma(x)$$

Properties:

  • Self-gated: the input gates itself
  • Smooth approximation of ReLU
  • Unbounded above, bounded below
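
A one-line sketch of the definition (PyTorch also ships this as F.silu / nn.SiLU):

import torch

def silu(x: torch.Tensor) -> torch.Tensor:
    # x * sigmoid(x)
    return x * torch.sigmoid(x)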

Gated Linear Units (GLU)

The key innovation in modern FFNs is gating: using one branch to control information flow through another.

Original GLU

$$\text{GLU}(x) = (xW_1) \otimes \sigma(xV)$$

The sigmoid-activated branch "gates" the linear branch.
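
A minimal sketch of the gating idea with two separate projections (the module and its layer names are illustrative, not from any particular library):

import torch
import torch.nn as nn

class GLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w = nn.Linear(d_model, d_ff)   # linear branch
        self.v = nn.Linear(d_model, d_ff)   # gate branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The sigmoid-activated branch gates the linear branch element-wise
        return self.w(x) * torch.sigmoid(self.v(x))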

SwiGLU

Shazeer (2020) proposed SwiGLU, combining Swish activation with gating:

$$\text{SwiGLU}(x) = \text{Swish}(xW_1) \otimes (xV)$$

or equivalently:

$$\text{SwiGLU}(x) = (xW_1 \otimes \sigma(xW_1)) \otimes (xV)$$

Final FFN with SwiGLU:

$$\text{FFN}_{\text{SwiGLU}}(x) = (\text{Swish}(xW_1) \otimes (xV))\, W_2$$

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # Note: 3 weight matrices instead of 2
        self.w1 = nn.Linear(d_model, d_ff, bias=False)   # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)   # down projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)   # up projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: Swish(x @ W1) * (x @ W3), projected back down by W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

Parameter Adjustment

SwiGLU has 3 weight matrices instead of 2. To maintain the same parameter count, the hidden dimension is reduced:

$$d_{ff}^{\text{SwiGLU}} = \tfrac{2}{3}\, d_{ff}^{\text{ReLU}}$$

Common configurations:

  • LLaMA: $d_{ff} = \frac{8}{3} d_{model}$, rounded up to a multiple of 256
  • With three matrices, the total is $3 \times \frac{8}{3} d^2 = 8d^2$ parameters, the same as a standard FFN with $d_{ff} = 4d$ (a quick check follows below)
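
A quick check of the sizing argument, with an illustrative hidden size and ignoring the rounding to 256:

d_model = 4096

# Standard FFN: two matrices with d_ff = 4 * d_model
standard_params = 2 * d_model * (4 * d_model)

# SwiGLU FFN: three matrices with d_ff = (8/3) * d_model
d_ff_swiglu = int(8 * d_model / 3)
swiglu_params = 3 * d_model * d_ff_swiglu

print(standard_params, swiglu_params)   # 134217728 vs 134209536 (equal up to rounding)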

GLU Variant Comparison

| Variant | Formula | Used By |
|---------|---------|---------|
| GLU | $(xW) \otimes \sigma(xV)$ | Original |
| ReGLU | $\text{ReLU}(xW) \otimes (xV)$ | - |
| GEGLU | $\text{GELU}(xW) \otimes (xV)$ | T5 variants |
| SwiGLU | $\text{Swish}(xW) \otimes (xV)$ | LLaMA, Mistral, PaLM |
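
Since the variants differ only in which activation is applied to the gate branch, one sketch can cover all of them; the activation argument below is just an illustrative way to parameterize this, not an API from any particular library:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable

class GLUVariantFFN(nn.Module):
    """Gated FFN; pass torch.sigmoid (GLU), F.relu (ReGLU), F.gelu (GEGLU), or F.silu (SwiGLU)."""
    def __init__(self, d_model: int, d_ff: int,
                 activation: Callable[[torch.Tensor], torch.Tensor] = F.silu):
        super().__init__()
        self.activation = activation
        self.w = nn.Linear(d_model, d_ff, bias=False)    # gate branch
        self.v = nn.Linear(d_model, d_ff, bias=False)    # linear branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)   # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.activation(self.w(x)) * self.v(x))

# geglu_ffn  = GLUVariantFFN(512, 1376, activation=F.gelu)
# swiglu_ffn = GLUVariantFFN(512, 1376, activation=F.silu)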

Mixture of Experts (MoE)

The Scaling Challenge

Dense models face a fundamental limitation: all parameters are used for every token. Scaling to 1T+ parameters makes inference expensive.

MoE Solution

Idea: Have multiple "expert" FFNs but only activate a subset for each token.

$$\text{MoE}(x) = \sum_{i=1}^{N} G(x)_i \cdot E_i(x)$$

where:

  • $E_i$ are expert networks (typically FFNs)
  • $G(x)$ is a gating/routing function that selects experts

Sparse Gating

The gating function typically uses top-k routing:

$$G(x) = \text{Softmax}(\text{TopK}(x \cdot W_g, k))$$

Only the top-k experts (usually k=1 or k=2) are activated per token.
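
A minimal sketch of top-k routing (shapes and sizes are illustrative):

import torch
import torch.nn.functional as F

d_model, n_experts, top_k = 512, 8, 2
x = torch.randn(4, 16, d_model)          # [batch, seq, d_model]
W_g = torch.randn(d_model, n_experts)    # router weights

logits = x @ W_g                          # [4, 16, 8]
weights, indices = torch.topk(logits, top_k, dim=-1)
weights = F.softmax(weights, dim=-1)      # renormalize over the selected experts
# indices[b, t] holds the 2 chosen expert ids; weights[b, t] their mixing weights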

Architecture

Input x
    │
    ▼
┌─────────┐
│ Router  │ ─────► Select Top-K experts
└─────────┘
    │
    ├──────────────┬──────────────┬──────────────┐
    ▼              ▼              ▼              ▼
┌────────┐   ┌────────┐   ┌────────┐   ┌────────┐
│Expert 1│   │Expert 2│   │Expert 3│   │Expert N│
│  (FFN) │   │  (FFN) │   │  (FFN) │   │  (FFN) │
└────────┘   └────────┘   └────────┘   └────────┘
    │              │              │              │
    └──────────────┴──────────────┴──────────────┘
                        │
                        ▼ Weighted sum
                     Output

Load Balancing

A major challenge is ensuring experts are used evenly. Auxiliary losses encourage balance:

$$\mathcal{L}_{aux} = \alpha \cdot N \sum_{i=1}^{N} f_i \cdot P_i$$

where:

  • $f_i$ = fraction of tokens routed to expert $i$
  • $P_i$ = average router probability for expert $i$
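
A sketch of how this loss can be computed from the router outputs, following a Switch-Transformer-style formulation; the function name and the alpha default are illustrative:

import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, indices, n_experts, alpha=0.01):
    # router_logits: [num_tokens, n_experts]; indices: [num_tokens, top_k]
    probs = F.softmax(router_logits, dim=-1)

    # f_i: fraction of tokens whose top-k selection includes expert i
    one_hot = F.one_hot(indices, n_experts).float()   # [num_tokens, top_k, n_experts]
    f = one_hot.sum(dim=1).mean(dim=0)                # [n_experts]

    # P_i: average router probability assigned to expert i
    P = probs.mean(dim=0)                             # [n_experts]

    return alpha * n_experts * torch.sum(f * P)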

MoE Models

| Model | Total Params | Active Params | Experts | Top-K |
|-------|--------------|---------------|---------|-------|
| Mixtral 8x7B | 46.7B | 12.9B | 8 | 2 |
| Mixtral 8x22B | 141B | 39B | 8 | 2 |
| GPT-4 (rumored) | ~1.8T | ~220B | 16 | 2 |

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k

        # Router
        self.gate = nn.Linear(d_model, n_experts, bias=False)

        # Experts (each is an FFN)
        self.experts = nn.ModuleList([
            SwiGLU(d_model, d_ff) for _ in range(n_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape

        # Router logits
        router_logits = self.gate(x)  # [B, T, n_experts]

        # Top-k routing: keep the k largest logits, renormalize with softmax
        weights, indices = torch.topk(router_logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)  # [B, T, top_k]

        # Compute expert outputs (simplified; real implementations batch by expert)
        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # Which top-k slots selected this expert
            expert_mask = (indices == i)             # [B, T, top_k]
            token_mask = expert_mask.any(dim=-1)     # [B, T]
            if token_mask.any():
                expert_out = expert(x[token_mask])   # [n_tokens, d_model]
                # Router weight for this expert at each selected token
                expert_weight = (weights * expert_mask).sum(dim=-1)[token_mask]
                output[token_mask] += expert_weight.unsqueeze(-1) * expert_out

        return output
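
A quick usage sketch with illustrative sizes (it assumes the SwiGLU class defined earlier):

moe = MoELayer(d_model=512, d_ff=1376, n_experts=8, top_k=2)
x = torch.randn(2, 16, 512)     # [batch, seq, d_model]
out = moe(x)
print(out.shape)                # torch.Size([2, 16, 512])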

Expert Parallelism

MoE enables a new dimension of parallelism:

| Parallelism | What's Split |
|-------------|--------------|
| Data | Batch across devices |
| Tensor | Layer weights across devices |
| Pipeline | Layers across devices |
| Expert | Different experts on different devices |

Modern FFN Architecture

Putting it all together, a modern LLM FFN block:

class ModernFFN(nn.Module):
    """LLaMA-style FFN with SwiGLU"""
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        # Default: (8/3) * d_model, rounded up to a multiple of 256
        if d_ff is None:
            d_ff = int(8 * d_model / 3)
            d_ff = 256 * ((d_ff + 255) // 256)  # round up to the next multiple of 256

        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

Summary

| Era | FFN Type | Activation | Parameters |
|-----|----------|------------|------------|
| 2017 | Standard | ReLU | 2 matrices |
| 2018 | Standard | GELU | 2 matrices |
| 2020+ | Gated (SwiGLU) | Swish | 3 matrices |
| 2023+ | MoE + SwiGLU | Swish | N × 3 matrices |


In the next post, we'll explore Part 5: Training Improvements - optimizers (Adam, AdamW, Lion), learning rate schedules, mixed precision training, and gradient checkpointing.
