transformers · mamba · ssm · rwkv · linear-attention · deep-learning

Transformer Deep Dive: Part 8 - Alternative Architectures

Beyond Transformers - State Space Models (SSMs), Mamba, Linear Attention, RWKV, and hybrid architectures that challenge the attention paradigm.


Suchinthaka W.

January 22, 2025 · 7 min read

The Transformer's quadratic complexity in sequence length ($O(n^2)$) has motivated research into alternative architectures that can achieve sub-quadratic complexity while maintaining competitive performance.

The Motivation

Self-attention computes pairwise interactions between all positions:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

For sequence length $n$, this requires $O(n^2)$ time and space. For very long sequences (100K+ tokens), this becomes prohibitive.
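To make the quadratic term concrete, here is a minimal PyTorch sketch that materializes the full $n \times n$ score matrix; the function name and sizes are illustrative, not from any particular library.

import math
import torch
import torch.nn.functional as F

def naive_attention(Q, K, V):
    # Q, K, V: (n, d_k). The scores tensor alone is n x n, so time and
    # memory both grow quadratically with sequence length n.
    scores = Q @ K.T / math.sqrt(Q.shape[-1])   # (n, n)
    return F.softmax(scores, dim=-1) @ V        # (n, d_k)

n, d_k = 4096, 64
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
out = naive_attention(Q, K, V)  # the (n, n) score matrix alone is ~64 MB in fp32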

The Quest for Linear Complexity

| Architecture | Time | Space | Long-range |
|--------------|------|-------|------------|
| Transformer | $O(n^2)$ | $O(n^2)$ | Excellent |
| SSM/Mamba | $O(n)$ | $O(n)$ | Good |
| Linear Attention | $O(n)$ | $O(n)$ | Limited |
| RWKV | $O(n)$ | $O(n)$ | Good |

State Space Models (SSMs)

The Continuous Formulation

SSMs are rooted in control theory. A continuous-time SSM is defined by:

$$\frac{dh(t)}{dt} = Ah(t) + Bx(t)$$

$$y(t) = Ch(t) + Dx(t)$$

where:

  • $x(t)$: Input signal
  • $h(t)$: Hidden state vector (dimension $N$)
  • $y(t)$: Output signal
  • $A$: State transition matrix ($N \times N$)
  • $B$: Input projection ($N \times 1$)
  • $C$: Output projection ($1 \times N$)
  • $D$: Skip connection (often 0)

Discretization

For digital computation, we discretize using step size $\Delta$:

$$\bar{A} = \exp(\Delta A)$$

$$\bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right) \cdot \Delta B$$

The discrete recurrence becomes:

$$h_k = \bar{A}h_{k-1} + \bar{B}x_k$$

$$y_k = Ch_k$$
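As a concrete reference, here is a minimal sketch of this recurrence for a diagonal $A$, including the discretization above; the function name and toy sizes are illustrative.

import torch

def ssm_recurrence(x, A, B, C, delta):
    # x: (L,) scalar input sequence; A: (N,) diagonal of the state matrix;
    # B, C: (N,); delta: scalar step size.
    A_bar = torch.exp(delta * A)        # exact ZOH discretization for diagonal A
    B_bar = (A_bar - 1.0) / A * B       # (Delta A)^-1 (exp(Delta A) - I) * Delta B, elementwise
    h = torch.zeros_like(A)
    ys = []
    for x_k in x:
        h = A_bar * h + B_bar * x_k     # h_k = A_bar h_{k-1} + B_bar x_k
        ys.append((C * h).sum())        # y_k = C h_k
    return torch.stack(ys)

L, N = 16, 8
y = ssm_recurrence(torch.randn(L), -(torch.rand(N) + 0.5), torch.randn(N), torch.randn(N), 0.1)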

Parallel Computation via Convolution

The recurrence can be unrolled as a convolution:

$$y = x * \bar{K}$$

where the kernel is:

$$\bar{K} = \left(C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^2\bar{B},\; \ldots,\; C\bar{A}^{L-1}\bar{B}\right)$$

This enables parallel training via FFT!
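A matching sketch of the convolutional view, again for a diagonal $A$: build $\bar{K}$ once, then apply it as a causal convolution with an FFT. Names are illustrative; the result should agree with the recurrence above up to floating-point error.

import torch

def ssm_conv(x, A_bar, B_bar, C):
    # x: (L,); A_bar, B_bar, C: (N,) for a diagonal SSM.
    L = x.shape[0]
    powers = A_bar.unsqueeze(0) ** torch.arange(L, dtype=A_bar.dtype).unsqueeze(1)  # (L, N): A_bar^l
    K_bar = (powers * (C * B_bar)).sum(-1)                                          # (L,): C A_bar^l B_bar
    # causal (linear, not circular) convolution via FFT: pad to length 2L
    X, Kf = torch.fft.rfft(x, 2 * L), torch.fft.rfft(K_bar, 2 * L)
    return torch.fft.irfft(X * Kf, 2 * L)[:L]

Building $\bar{K}$ and applying the FFT makes training parallel over the whole sequence, while the recurrent form above gives constant-time, constant-memory steps at inference.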

The S4 Innovation

S4 (Structured State Space Sequence model) made SSMs practical by:

  1. HiPPO initialization: A special $A$ matrix that captures long-range dependencies (a sketch of this matrix follows the list)
  2. Diagonal structure: Efficient computation with $O(N)$ instead of $O(N^2)$
  3. Parallel scan: GPU-efficient recurrence computation
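As a sketch of point 1, the function below builds the commonly cited HiPPO-LegS matrix used to initialize $A$, shown here in its plain dense form (S4 then exploits its structure for efficiency); this is an illustration, not a full S4 implementation.

import torch

def hippo_legs(N: int) -> torch.Tensor:
    # HiPPO-LegS matrix: A[n, k] = -sqrt((2n+1)(2k+1)) for n > k,
    # -(n + 1) on the diagonal, and 0 above the diagonal.
    n = torch.arange(N, dtype=torch.float32).unsqueeze(1)
    k = torch.arange(N, dtype=torch.float32).unsqueeze(0)
    A = -torch.sqrt((2 * n + 1) * (2 * k + 1)) * (n > k)
    return A - torch.diag(torch.arange(1, N + 1, dtype=torch.float32))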

Mamba: Selective State Spaces

The Key Insight

Traditional SSMs use fixed (input-independent) A, B, C matrices. Mamba makes them input-dependent:

$$B_t = f_B(x_t), \quad C_t = f_C(x_t), \quad \Delta_t = f_\Delta(x_t)$$

This enables content-aware processing—the model can decide what information to remember or forget.

The Selectivity Mechanism

Input: "The capital of France is"

Fixed SSM: Treats all tokens equally
Mamba:     Attends strongly to "France", weakly to "The"

Architecture

                         Mamba Block

Input ──► Linear ──┬──► Conv1D ──► SiLU ──► SSM ──┐
                   │                              ├──► × ──► Linear ──► Out
                   └──────────► SiLU ─────────────┘

Key components:

  1. Linear projection to expand dimension
  2. 1D convolution for local context
  3. Selective SSM for sequence mixing
  4. Gating with SiLU activation

Mamba Implementation (Simplified)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # Projections
        self.in_proj = nn.Linear(d_model, 2 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Conv for local context
        self.conv = nn.Conv1d(d_model, d_model, d_conv, padding=d_conv-1, groups=d_model)

        # SSM parameters (input-dependent)
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1)  # B, C, delta
        self.dt_proj = nn.Linear(1, d_model)

        # Fixed A (log scale for stability)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, D = x.shape

        # Project and split
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv
        x = self.conv(x.transpose(1, 2))[:, :, :L].transpose(1, 2)
        x = F.silu(x)

        # SSM (selective scan - simplified)
        y = self.selective_scan(x)

        # Gating
        y = y * F.silu(z)

        return self.out_proj(y)

    def selective_scan(self, x):
        # Naive sequential reference scan; the real Mamba fuses this into a
        # hardware-aware parallel-scan CUDA kernel.
        B, L, D = x.shape
        bcd = self.x_proj(x)  # (B, L, 2*d_state + 1) -> input-dependent B, C, delta
        B_t, C_t, delta = bcd.split([self.d_state, self.d_state, 1], dim=-1)
        delta = F.softplus(self.dt_proj(delta))     # (B, L, D), positive step sizes
        A = -torch.exp(self.A_log)                  # (d_state,), negative for stability

        h = x.new_zeros(B, D, self.d_state)
        ys = []
        for t in range(L):
            dA = torch.exp(delta[:, t, :, None] * A)          # discretized A
            dB = delta[:, t, :, None] * B_t[:, t, None, :]    # discretized B
            h = dA * h + dB * x[:, t, :, None]                # h_t = A_bar h_{t-1} + B_bar x_t
            ys.append((h * C_t[:, t, None, :]).sum(-1))       # y_t = C_t h_t
        return torch.stack(ys, dim=1)               # (B, L, D)

Mamba Performance

| Model | Params | Perplexity | Throughput |
|-------|--------|------------|------------|
| Transformer | 1.4B | 14.2 | 1× |
| Mamba | 1.4B | 14.0 | 5× |

Mamba matches Transformer quality with 5× higher inference throughput!

Linear Attention

The Idea

Standard attention:

$$\text{Attn}(Q, K, V) = \text{softmax}(QK^\top)V$$

Linear attention removes softmax:

$$\text{LinearAttn}(Q, K, V) = \phi(Q)\,(\phi(K)^\top V)$$

where $\phi$ is a feature map.

The Trick: Associativity

By reordering computation:

$$\phi(Q)(\phi(K)^\top V) = \phi(Q) \cdot \underbrace{(\phi(K)^\top V)}_{d \times d}$$

We compute the $d \times d$ matrix first, then multiply with each query. This is $O(n \cdot d^2)$ instead of $O(n^2 \cdot d)$.
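A minimal sketch of this reordering, using the ELU+1 feature map from the table below; names are illustrative, and the per-query normalizer that practical linear attention adds is only noted in a comment.

import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # Q, K: (n, d); V: (n, d_v). Forms the small (d, d_v) matrix phi(K)^T V first,
    # so no n x n score matrix is ever materialized.
    # (Practical variants also divide by phi(Q) @ phi(K).sum(0) to normalize.)
    phi = lambda x: F.elu(x) + 1.0      # ELU+1 feature map: positive, smooth
    phi_Q, phi_K = phi(Q), phi(K)
    KV = phi_K.T @ V                    # (d, d_v), cost O(n * d * d_v)
    return phi_Q @ KV                   # (n, d_v), cost O(n * d * d_v)

n, d = 4096, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))
out = linear_attention(Q, K, V)         # never builds an (n, n) matrix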

Feature Maps

| Method | $\phi(x)$ | Properties |
|--------|-----------|------------|
| Linear | $x$ | Simplest, limited expressivity |
| ELU+1 | $\text{ELU}(x) + 1$ | Positive, smooth |
| Random Features | $\frac{1}{\sqrt{m}}\exp(Wx)$ | Approximates softmax |
| Performer | Random Fourier features | Unbiased approximation |

Limitations

  • No sharp attention patterns: Can't focus on single tokens
  • Approximation quality: May not match softmax exactly
  • Training stability: Can be harder to train

RWKV

The Concept

RWKV combines the best of RNNs and Transformers:

  • Training: Parallelizable like Transformers
  • Inference: $O(1)$ per token like RNNs

WKV Mechanism

RWKV uses a novel "WKV" (weighted key-value) mechanism:

$$\text{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i + e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}$$

where:

  • $w$: Learned decay rate per channel
  • $u$: Learned bonus for the current token
  • $k, v$: Key and value projections
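A direct, unoptimized evaluation of this formula (illustrative names; production RWKV kernels compute the same quantity recurrently, with tricks to keep the exponentials numerically stable):

import torch

def naive_wkv(k, v, w, u):
    # k, v: (T, C) keys and values; w: (C,) per-channel decay; u: (C,) current-token bonus.
    T, C = k.shape
    out = torch.empty_like(v)
    for t in range(T):
        if t > 0:
            # past tokens, weighted by how far back they are
            decay = torch.exp(-(t - 1 - torch.arange(t).unsqueeze(1)) * w + k[:t])  # (t, C)
            num, den = (decay * v[:t]).sum(0), decay.sum(0)
        else:
            num, den = k.new_zeros(C), k.new_zeros(C)
        cur = torch.exp(u + k[t])        # current token gets the bonus u instead of the decay
        out[t] = (num + cur * v[t]) / (den + cur)
    return out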

RWKV Architecture

Token → Embedding → [RWKV Block] × L → LayerNorm → Output

RWKV Block:
├── LayerNorm
├── Time Mixing (WKV)
├── Residual
├── LayerNorm
├── Channel Mixing (FFN-like)
└── Residual

Comparison to Attention

| Aspect | Attention | RWKV |
|--------|-----------|------|
| Complexity | $O(n^2)$ | $O(n)$ |
| Long-range | Excellent | Good (with decay) |
| KV cache | Grows with $n$ | Fixed size |
| Training | Parallel | Parallel |
| Inference | Parallel | Sequential (fast) |

Hybrid Architectures

The Best of Both Worlds

Some architectures combine attention and linear methods:

Jamba (AI21):

  • Alternates Mamba and Attention layers
  • Uses MoE for scaling
  • Mamba for efficiency, Attention for complex patterns

Griffin (Google):

  • RNN-like gated linear recurrence
  • Local attention for nearby context
  • MLP for channel mixing

Design Patterns

Hybrid Block Options:

1. Interleaved:
   [Mamba] → [Attn] → [Mamba] → [Attn] → ...

2. Ratio-based:
   [Mamba] × 3 → [Attn] → [Mamba] × 3 → [Attn] → ...

3. Hierarchical:
   Local: Mamba
   Global: Sparse Attention

4. Parallel:
   [Mamba] ─┐
            ├─► Add
   [Attn] ──┘
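As a sketch of the ratio-based pattern (option 2 above), the block below interleaves the simplified MambaBlock from earlier with standard self-attention layers. The ratio, the pre-norm residual wiring, and the use of nn.MultiheadAttention (without causal masking, for brevity) are illustrative choices, not how Jamba or Griffin are actually implemented.

import torch
import torch.nn as nn

class HybridStack(nn.Module):
    # Ratio-based hybrid: every (ratio + 1)-th layer is attention, the rest are Mamba.
    def __init__(self, d_model: int, n_layers: int = 8, ratio: int = 3, n_heads: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            if (i + 1) % (ratio + 1) == 0 else MambaBlock(d_model)
            for i in range(n_layers)
        ])
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, L, D); pre-norm residual around every layer
        for norm, layer in zip(self.norms, self.layers):
            h = norm(x)
            if isinstance(layer, nn.MultiheadAttention):
                h, _ = layer(h, h, h, need_weights=False)
            else:
                h = layer(h)
            x = x + h
        return x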

Jamba Architecture

Block 0: Attention + MoE
Blocks 1-7: Mamba + MoE
Block 8: Attention + MoE
Blocks 9-15: Mamba + MoE
...

Ratio: 1 Attention : 7 Mamba

Comparison Summary

| Architecture | Training | Inference | Long Context | Quality |
|--------------|----------|-----------|--------------|---------|
| Transformer | Parallel | Parallel | $O(n^2)$ memory | Best |
| Mamba | Parallel | Sequential | $O(n)$ memory | Near-best |
| Linear Attn | Parallel | Parallel | $O(n)$ memory | Good |
| RWKV | Parallel | Sequential | $O(1)$ memory | Good |
| Jamba | Parallel | Mixed | Efficient | Near-best |

When to Use What?

| Use Case | Recommendation |
|----------|----------------|
| Best quality, moderate context | Transformer |
| Very long context (100K+) | Mamba or Hybrid |
| Resource-constrained inference | RWKV |
| Streaming applications | Mamba, RWKV |
| Maximum throughput | Mamba |

The Future

The field is rapidly evolving:

  1. Hybrid architectures may dominate, combining attention's expressivity with SSM efficiency
  2. Hardware co-design will optimize for specific architectures
  3. Task-specific architectures may emerge for different use cases
  4. Scaling laws for alternative architectures are still being understood

Conclusion

While "Attention Is All You Need" revolutionized NLP, the quest for efficiency has spawned alternatives that challenge this paradigm. State Space Models, Mamba, and hybrid architectures offer compelling tradeoffs between quality, speed, and memory efficiency.

The transformer isn't going away—but it's no longer the only game in town.


This concludes the "Transformer Deep Dive" series. We've covered:

  1. Part 1: Original Transformer (2017)
  2. Part 2: Architecture Changes (Decoder-only, Pre-LN, RMSNorm)
  3. Part 3: Attention Modifications (RoPE, GQA, FlashAttention)
  4. Part 4: FFN Modifications (SwiGLU, MoE)
  5. Part 5: Training Improvements (AdamW, Mixed Precision)
  6. Part 6: Inference Optimization (KV-cache, Quantization)
  7. Part 7: Minor But Important Changes
  8. Part 8: Alternative Architectures

Thanks for following along!
