# SWA + MoE: Attention and Sparse Experts

Layers 9–16 combine Sliding Window Attention (SWA) for local reasoning with a Shared-Expert Mixture-of-Experts (MoE) for sparse computation. Layers 17–24 drop attention and pair Axon-SSM with the same MoE design.

## Sliding Window Attention

### Overview

| Property | Value |
| --- | --- |
| Hidden dim | 2,560 |
| Num heads | 32 |
| Head dim (\(d_k\)) | 80 (\(2560 / 32\)) |
| Window size | 4,096 |
| Complexity | \(O(n \cdot w)\) instead of \(O(n^2)\) |
| Layers | 9–16 only |
| KV cache | External KVCache for incremental decoding |
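To make the complexity row concrete, here is the score-matrix arithmetic at an illustrative context length (only the 4,096 window comes from the table above; the context length is chosen for the example):

```python
n = 32_768        # hypothetical context length
w = 4_096         # SWA window size
full = n * n      # score entries for dense O(n^2) attention
windowed = n * w  # upper bound with a sliding window, O(n * w)
print(windowed / full)  # 0.125, i.e. an 8x reduction at this length
```

The saving grows with context length: the ratio is simply \(w / n\).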

### Algorithm

```
Input x: (batch, seq_len, 2560)
            │
    ┌───────┼───────┐
    ▼       ▼       ▼
  q_proj  k_proj  v_proj     (each 2560 → 2560)
    │       │       │
    ▼       ▼       ▼
  reshape + transpose  →  (batch, 32, seq_len, 80)
    │       │       │
    │       └───┬───┘
    │           │
    │    ┌──────┴───────┐
    │    │ cache update │  (append K, V if decoding)
    │    └──────┬───────┘
    │           │
    ▼           ▼
    Q          K, V
    │           │
    └─────┬─────┘
          ▼
  ┌───────────────┐
  │ scores = QKᵀ  │   scaled by 1/√d_k
  │    / √80      │
  └───────┬───────┘
          ▼
  ┌───────────────┐
  │ + sliding     │   causal AND window mask
  │   window mask │
  └───────┬───────┘
          ▼
  ┌───────────────┐
  │    softmax    │
  └───────┬───────┘
          ▼
  ┌───────────────┐
  │      × V      │
  └───────┬───────┘
          ▼
  reshape + transpose  →  (batch, seq_len, 2560)
          │
          ▼
      o_proj (2560 → 2560)
          │
          ▼
      Output: (batch, seq_len, 2560)
```

### Attention Formula

\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}} + M\right) V\]

Where \(M\) is the combined mask applied element-wise to the attention scores.
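The formula can be exercised end to end with a small NumPy sketch (toy shapes and random inputs; this stands in for, and is not, the model's MLX implementation):

```python
import numpy as np

def attention(Q, K, V, M):
    """softmax(Q K^T / sqrt(d_k) + M) V, applied per head."""
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k) + M
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(1, 2, 5, 8)) for _ in range(3))  # (batch, heads, seq, d_k)
M = np.triu(np.full((5, 5), -np.inf), k=1)  # causal mask: -inf above the diagonal
out = attention(Q, K, V, M)
```

With this mask the first query can attend only to the first key, so `out[:, :, 0]` reproduces `V[:, :, 0]` exactly.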

### Sliding Window Mask

Each position can attend only to previous positions within the window:

\[M_{i,j} = \begin{cases} 0 & \text{if } j \leq i \text{ and } i - j < w \\ -\infty & \text{otherwise} \end{cases}\]

The mask combines two conditions:

1. Causal (\(j \leq i\)): tokens cannot attend to future positions
2. Window (\(i - j < 4096\)): tokens cannot attend beyond the window

In code:

```python
causal_mask = k_pos[None, :] <= (q_pos[:, None] + causal_offset)
window_mask = (q_pos[:, None] + causal_offset) - k_pos[None, :] < self.window_size
mask = mx.where(causal_mask & window_mask, 0.0, -mx.inf)
```

The `causal_offset` accounts for the KV cache length during autoregressive decoding, where `kv_len > seq_len`: each query's absolute position is its local index plus the number of previously cached tokens.
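The same construction can be sanity-checked in NumPy at toy sizes (window of 3, a 6-entry cache, 2 new queries; variable names follow the snippet above):

```python
import numpy as np

window_size = 3             # toy window; the model uses 4,096
kv_len, seq_len = 6, 2      # decode step: 2 new queries against 6 cached keys
causal_offset = kv_len - seq_len

q_pos = np.arange(seq_len)  # local query positions
k_pos = np.arange(kv_len)   # absolute key positions

causal_mask = k_pos[None, :] <= (q_pos[:, None] + causal_offset)
window_mask = (q_pos[:, None] + causal_offset) - k_pos[None, :] < window_size
mask = np.where(causal_mask & window_mask, 0.0, -np.inf)
# The first query sits at absolute position 4: it sees keys 2, 3, 4 only.
```

Every row has at most `window_size` finite entries, which is exactly the property that caps the per-query cost at \(O(w)\).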

### KV Cache for Incremental Decoding

Only the 8 SWA layers (9–16) produce a KV cache. During prefill, K and V tensors for the full prompt are stored. During decode, each new token's K/V row is appended:

```
Prefill:  K, V shape = (batch, 32, prompt_len, 80)
Decode:   K, V shape = (batch, 32, prompt_len + n_decoded, 80)
```

The cache grows linearly with generated tokens. For long contexts, TurboQuant compresses these KV tensors by 6× or more (see Memory Budget).
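A minimal sketch of this prefill-then-append behavior (NumPy concatenation for clarity; a production KVCache would preallocate and write in place):

```python
import numpy as np

class ToyKVCache:
    """Illustrative append-only K/V cache for one SWA layer."""
    def __init__(self):
        self.k = self.v = None

    def update(self, k_new, v_new):
        if self.k is None:  # prefill: store K/V for the whole prompt
            self.k, self.v = k_new, v_new
        else:               # decode: append one row per generated token
            self.k = np.concatenate([self.k, k_new], axis=2)
            self.v = np.concatenate([self.v, v_new], axis=2)
        return self.k, self.v

cache = ToyKVCache()
shape = (1, 32, 10, 80)  # prompt_len = 10
cache.update(np.zeros(shape), np.zeros(shape))
for _ in range(3):       # three decode steps
    k, v = cache.update(np.zeros((1, 32, 1, 80)), np.zeros((1, 32, 1, 80)))
```

After prefill plus three decoded tokens, K and V have shape `(1, 32, 13, 80)`, matching the decode line above.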

## Shared-Expert MoE

### Overview

| Property | Value |
| --- | --- |
| Total experts | 8 routed + 1 shared |
| Top-\(k\) | 2 |
| Expert FFN dim | 4,096 |
| Active params per token | ~1.4B (~44% of total 3.2B) |
| Layers | 9–16 (with SWA) and 17–24 (with SSM) |

### Architecture

```
Input x: (batch, seq_len, 2560)
            │
    ┌───────┴───────┐
    │               │
    ▼               ▼
 ┌──────────┐   ┌───────────────┐
 │  Router  │   │ Shared Expert │   always active
 │  (gate)  │   │   (MLP+GLU)   │
 └────┬─────┘   └───────┬───────┘
      │                 │
      ▼                 │
 ┌───────────────┐      │
 │ softmax gate  │      │
 │ → top-2 pick  │      │
 └───────┬───────┘      │
         │              │
         ▼              │
   ┌──────────┐         │
   │ Expert 0 │         │
   │ Expert 3 │  ← selected by routing
   │   ...    │         │
   └────┬─────┘         │
        │               │
        ▼               ▼
   weighted sum    sigmoid gate × shared_out
        │               │
        └───────┬───────┘
                ▼
          y + shared_out
                │
                ▼
      Output: (batch, seq_len, 2560)
```

### Routing

The gate produces expert scores via a linear projection + softmax:

```python
gates = self.gate(x)                                # (batch, seq_len, 8)
gates = mx.softmax(gates, axis=-1)
inds = mx.argpartition(-gates, kth=k - 1)[..., :k]  # top-2 indices
scores = mx.take_along_axis(gates, inds, axis=-1)   # top-2 scores
```

The top-2 indices (`inds`) are passed through `stop_gradient` to prevent routing gradient instability. Expert outputs are weighted by their softmax scores and summed.
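The routing math is easy to replay in NumPy (`np.argpartition` mirrors `mx.argpartition`; one toy token with hand-picked logits):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

k = 2
logits = np.array([[0.1, 2.0, -1.0, 0.5, 1.5, 0.0, -0.5, 0.3]])  # (tokens, 8)
gates = softmax(logits)
inds = np.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]  # top-2 expert ids
scores = np.take_along_axis(gates, inds, axis=-1)            # their gate weights
```

Experts 1 and 4 win here. Note that the two scores are taken from the full 8-way softmax, so they need not sum to 1 unless renormalized.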

### SwitchGLU: Expert-Routed SwiGLU

Each routed expert is implemented as a SwitchGLU, a SwiGLU MLP that routes tokens to per-expert weight matrices using MLX's `gather_mm`:

```
Input x + expert indices
        │
   ┌────┼────┐
   ▼    ▼    ▼
 gate  up   down      (SwitchLinear per expert)
  │    │    │
  ▼    ▼    │
 SwiGLU: SiLU(gate) × up
        │
        ▼
     down_proj
        │
        ▼
    Expert output
```

Gather-sort optimization: when there are at least 64 token/expert assignments (`indices.size >= 64`), tokens are sorted by expert index before `gather_mm`. This groups same-expert tokens contiguously, improving memory access patterns:

```python
if indices.size >= 64:
    x_sorted, idx_flat, inv_order = _gather_sort(x_exp, indices)
    # ... process sorted ...
    y = _scatter_unsort(y_flat, inv_order, shape=(B, L, K))
```
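The sort/unsort round trip can be sketched with `np.argsort` (the helper names here are hypothetical stand-ins for `_gather_sort`/`_scatter_unsort`):

```python
import numpy as np

def gather_sort(x, expert_ids):
    """Reorder token rows so same-expert tokens are contiguous."""
    order = np.argsort(expert_ids, kind="stable")
    inv_order = np.argsort(order)  # permutation that undoes the sort
    return x[order], expert_ids[order], inv_order

def scatter_unsort(y_sorted, inv_order):
    """Restore the original token order after expert processing."""
    return y_sorted[inv_order]

x = np.arange(12, dtype=float).reshape(6, 2)  # 6 token rows
expert_ids = np.array([3, 0, 3, 1, 0, 1])
x_sorted, ids_sorted, inv_order = gather_sort(x, expert_ids)
y = scatter_unsort(x_sorted, inv_order)  # identity here: no expert applied
```

After sorting, `ids_sorted` reads `[0, 0, 1, 1, 3, 3]`, so each expert's tokens form one contiguous slice; unsorting restores the input order exactly.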

### Shared Expert

The shared expert is a standard SwiGLU MLP (`gate_proj`, `up_proj`, `down_proj`) applied to all tokens unconditionally. Its output is gated by a learned sigmoid:

\[y_\text{shared} = \sigma(W_\text{gate} \cdot x) \cdot \text{MLP}(x)\]

This allows the model to dynamically blend shared knowledge (always available) with routed expert knowledge (specialized):

```python
shared_out = self.shared_expert(x)
shared_out = mx.sigmoid(self.shared_expert_gate(x)) * shared_out
```
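A NumPy sketch of this gated shared path (random toy weights; the gate here is a per-token scalar, one common choice, and the shapes are assumptions rather than the model's exact layout):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def silu(z):
    return z * sigmoid(z)

rng = np.random.default_rng(0)
d, h = 8, 16  # toy dims standing in for 2,560 and 4,096

W_gate = rng.normal(size=(d, h)) * 0.1  # SwiGLU gate_proj
W_up = rng.normal(size=(d, h)) * 0.1    # up_proj
W_down = rng.normal(size=(h, d)) * 0.1  # down_proj
w_sg = rng.normal(size=(d, 1)) * 0.1    # shared_expert_gate (scalar per token)

x = rng.normal(size=(4, d))  # 4 token rows
shared_out = (silu(x @ W_gate) * (x @ W_up)) @ W_down  # SwiGLU MLP
y_shared = sigmoid(x @ w_sg) * shared_out              # sigmoid-gated output
```

Because the gate is strictly between 0 and 1, the shared contribution is scaled smoothly rather than switched on or off.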

### Why a Shared Expert?

| Without shared expert | With shared expert |
| --- | --- |
| All knowledge must be routed | General knowledge always available |
| Router errors cause information loss | Shared expert acts as a safety net |
| Load balancing is critical | Less sensitive to routing quality |
| Each expert must re-learn common patterns | Shared expert handles common patterns; experts specialize |

## Block Composition

The SWA+MoE block and SSM+MoE block share the same MoE design but differ in their first sub-layer:

### AxonSWAMoEBlock (Layers 9–16)

```
x → RMSNorm → SWA → (+residual) → RMSNorm → MoE → (+residual) → output
               ↑                             ↑
           KV cache                      no cache
```

### AxonSSMMoEBlock (Layers 17–24)

```
x → RMSNorm → SSM → (+residual) → RMSNorm → MoE → (+residual) → output
               ↑                             ↑
           SSM state                     no cache
```

Both blocks use pre-norm residual connections with separate RMSNorm instances for each sub-layer.
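The shared pre-norm skeleton, as a NumPy sketch (identity functions stand in for the SWA/SSM mixer and the MoE layer; RMSNorm weights initialized to ones):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    """Normalize by root-mean-square over the feature axis, then scale."""
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    return x / rms * weight

def block(x, mixer, moe, w1, w2):
    x = x + mixer(rms_norm(x, w1))  # first sub-layer (SWA or SSM) + residual
    x = x + moe(rms_norm(x, w2))    # MoE sub-layer + residual
    return x

d = 8
x = np.random.default_rng(0).normal(size=(2, 3, d))
identity = lambda t: t  # placeholder sub-layers
y = block(x, identity, identity, np.ones(d), np.ones(d))
```

Note the norms sit inside the residual branches (pre-norm), so the residual stream itself is never normalized away.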

## Sparsity and Thermal Benefits

With top-2 out of 8 experts, only 25% of routed expert parameters are active per token. Combined with the shared expert:

  - Active parameters per token: ~1.4B (44% of 3.2B total)
  - Dormant parameters: ~1.8B (56%), which cost no compute and no memory bandwidth
  - Thermal impact: lower chip utilization enables sustained inference/training on a fanless MacBook Air
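The headline fractions follow from the configuration numbers above (simple arithmetic; the ~44% active figure also counts attention/SSM weights, embeddings, and the always-on shared expert, which is why it exceeds the 25% routed fraction):

```python
top_k, n_routed = 2, 8
routed_active_frac = top_k / n_routed
print(routed_active_frac)  # 0.25: a quarter of routed expert weights per token

total_params, active_params = 3.2e9, 1.4e9
print(active_params / total_params)  # 0.4375, i.e. ~44% active per token
```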

The `swiglu` activation function is also JIT-compiled with `@mx.compile`:

```python
import mlx.core as mx
import mlx.nn as nn


@mx.compile
def swiglu(x: mx.array, gate: mx.array) -> mx.array:
    return nn.silu(gate) * x
```

## See also