SWA + MoE: Attention and Sparse Experts¶
Layers 9–16 combine Sliding Window Attention (SWA) for local reasoning with a Shared-Expert Mixture-of-Experts (MoE) for sparse computation. Layers 17–24 drop attention and pair Axon-SSM with the same MoE design.
Sliding Window Attention¶
Overview¶
| Property | Value |
|---|---|
| Hidden dim | 2,560 |
| Num heads | 32 |
| Head dim (\(d_k\)) | 80 (\(2560 / 32\)) |
| Window size | 4,096 |
| Complexity | \(O(n \cdot w)\) instead of \(O(n^2)\) |
| Layers | 9–16 only |
| KV cache | External KVCache for incremental decoding |
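The complexity row can be made concrete with a quick count of attention-score entries per head (a back-of-the-envelope sketch; the helper names are illustrative):

```python
# Count attention-score entries per head for causal full attention vs. a
# 4,096-token sliding window.
def full_attention_entries(n: int) -> int:
    # Position i attends to i + 1 positions -> n(n+1)/2 total.
    return n * (n + 1) // 2

def swa_entries(n: int, w: int = 4096) -> int:
    # Position i attends to at most w positions (window includes itself).
    return sum(min(i + 1, w) for i in range(n))

n = 32_768
print(full_attention_entries(n) / swa_entries(n))  # ~4.3x fewer entries at 32k
```

For sequences no longer than the window the two counts coincide; the savings grow linearly with sequence length beyond that.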
Algorithm¶
```
Input x: (batch, seq_len, 2560)
          │
     ┌────┼────┐
     ▼    ▼    ▼
  q_proj k_proj v_proj      (each 2560 → 2560)
     │    │    │
     ▼    ▼    ▼
  reshape + transpose → (batch, 32, seq_len, 80)
     │    │    │
     │    └──┬─┘
     │       ▼
     │ ┌──────────────┐
     │ │ cache update │  (append K, V if decoding)
     │ └──────┬───────┘
     │        │
     ▼        ▼
     Q       K, V
     │        │
     └───┬────┘
         ▼
 ┌───────────────┐
 │ scores = QKᵀ  │  scaled by 1/√d_k
 │     / √80     │
 └───────┬───────┘
         ▼
 ┌───────────────┐
 │  + sliding    │  causal AND window mask
 │  window mask  │
 └───────┬───────┘
         ▼
 ┌───────────────┐
 │    softmax    │
 └───────┬───────┘
         ▼
 ┌───────────────┐
 │      × V      │
 └───────┬───────┘
         ▼
 reshape + transpose → (batch, seq_len, 2560)
         │
         ▼
 o_proj (2560 → 2560)
         │
         ▼
 Output: (batch, seq_len, 2560)
```
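The pipeline above can be sketched end to end in NumPy (toy dimensions and random stand-in weights; the actual module uses MLX projections and the external KV cache):

```python
import numpy as np

def sliding_window_attention(x, n_heads=4, window=8):
    """Toy single-batch SWA forward pass; shapes mirror the diagram above."""
    seq, dim = x.shape
    d_k = dim // n_heads
    rng = np.random.default_rng(0)
    # Stand-ins for q_proj / k_proj / v_proj / o_proj (each dim -> dim).
    wq, wk, wv, wo = (rng.standard_normal((dim, dim)) * 0.02 for _ in range(4))
    q = (x @ wq).reshape(seq, n_heads, d_k).transpose(1, 0, 2)  # (heads, seq, d_k)
    k = (x @ wk).reshape(seq, n_heads, d_k).transpose(1, 0, 2)
    v = (x @ wv).reshape(seq, n_heads, d_k).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)            # (heads, seq, seq)
    i, j = np.arange(seq)[:, None], np.arange(seq)[None, :]
    mask = np.where((j <= i) & (i - j < window), 0.0, -np.inf)  # causal AND window
    scores = scores + mask
    attn = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    attn = attn / attn.sum(axis=-1, keepdims=True)
    out = (attn @ v).transpose(1, 0, 2).reshape(seq, dim)
    return out @ wo

x = np.random.default_rng(1).standard_normal((16, 32))
print(sliding_window_attention(x).shape)  # (16, 32)
```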
Attention Formula¶
\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V
\]
Where \(M\) is the combined mask applied element-wise to the attention scores.
Sliding Window Mask¶
Each position can attend only to previous positions within the window:
\[
M_{ij} =
\begin{cases}
0 & \text{if } j \leq i \text{ and } i - j < 4096 \\
-\infty & \text{otherwise}
\end{cases}
\]
The mask combines two conditions:
- Causal: \(j \leq i\) – tokens cannot attend to future positions
- Window: \(i - j < 4096\) – tokens cannot attend beyond the window
In code:
```python
causal_mask = k_pos[None, :] <= (q_pos[:, None] + causal_offset)
window_mask = (q_pos[:, None] + causal_offset) - k_pos[None, :] < self.window_size
mask = mx.where(causal_mask & window_mask, 0.0, -mx.inf)
```
The causal_offset accounts for KV cache length during autoregressive decoding, where kv_len > seq_len.
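The same mask logic, including a nonzero causal_offset during decoding, can be reproduced in NumPy (np.where stands in for mx.where; swa_mask is an illustrative helper name):

```python
import numpy as np

def swa_mask(seq_len, kv_len, window_size=4096):
    """Build the combined causal + window mask, as in the snippet above."""
    q_pos = np.arange(seq_len)
    k_pos = np.arange(kv_len)
    causal_offset = kv_len - seq_len  # query i sits at absolute position i + offset
    causal = k_pos[None, :] <= (q_pos[:, None] + causal_offset)
    window = (q_pos[:, None] + causal_offset) - k_pos[None, :] < window_size
    return np.where(causal & window, 0.0, -np.inf)

# Decoding one new token against a 6-entry cache with a window of 4:
m = swa_mask(seq_len=1, kv_len=6, window_size=4)
print(m)  # only the last 4 cache positions are visible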
KV Cache for Incremental DecodingΒΆ
Only the 8 SWA layers (9β16) produce a KV cache. During prefill, K and V tensors for the full prompt are stored. During decode, each new token's K/V row is appended:
Prefill: K, V shape = (batch, 32, prompt_len, 80)
Decode: K, V shape = (batch, 32, prompt_len + n_decoded, 80)
The cache grows linearly with generated tokens. For long contexts, TurboQuant compresses these KV tensors by 6Γ or more (see Memory Budget).
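The per-layer shapes above make the cache footprint easy to estimate (fp16 storage assumed here for illustration; kv_cache_bytes is a hypothetical helper):

```python
def kv_cache_bytes(tokens, n_layers=8, n_heads=32, head_dim=80, bytes_per=2):
    # K and V each have shape (n_heads, tokens, head_dim) per layer -> factor 2.
    return 2 * n_layers * n_heads * tokens * head_dim * bytes_per

full = kv_cache_bytes(32_768)
print(f"fp16: {full / 2**20:.0f} MiB, 6x-compressed: {full / 6 / 2**20:.0f} MiB")
```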
Shared-Expert MoE¶
Overview¶
| Property | Value |
|---|---|
| Total experts | 8 routed + 1 shared |
| Top-\(k\) | 2 |
| Expert FFN dim | 4,096 |
| Active params per token | ~1.4B (~44% of total 3.2B) |
| Layers | 9–16 (with SWA) and 17–24 (with SSM) |
Architecture¶
```
Input x: (batch, seq_len, 2560)
         │
    ┌────┴─────┐
    │          │
    ▼          ▼
┌────────┐ ┌───────────────┐
│ Router │ │ Shared Expert │  always active
│ (gate) │ │   (MLP+GLU)   │
└───┬────┘ └───────┬───────┘
    │              │
    ▼              │
┌──────────────┐   │
│ softmax gate │   │
│ → top-2 pick │   │
└──────┬───────┘   │
       │           │
       ▼           │
 ┌──────────┐      │
 │ Expert 0 │      │
 │ Expert 3 │ ← selected by routing
 │   ...    │      │
 └────┬─────┘      │
      │            │
      ▼            ▼
 weighted sum   sigmoid gate × shared_out
      │            │
      └─────┬──────┘
            ▼
     y + shared_out
            │
            ▼
Output: (batch, seq_len, 2560)
```
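A minimal NumPy sketch of the whole diagram, with toy dimensions and plain callables for the experts (all names here are illustrative, not the real module API):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def moe_forward(x, experts, shared, gate_w, shared_gate_w, k=2):
    """Toy per-token shared-expert MoE; x: (tokens, dim)."""
    gates = softmax(x @ gate_w)                        # (tokens, n_experts)
    inds = np.argpartition(-gates, k - 1, axis=-1)[:, :k]
    scores = np.take_along_axis(gates, inds, axis=-1)  # top-k weights
    y = np.zeros_like(x)
    for t in range(x.shape[0]):                        # routed path
        for e, s in zip(inds[t], scores[t]):
            y[t] += s * experts[e](x[t])
    sig = 1.0 / (1.0 + np.exp(-(x @ shared_gate_w)))   # sigmoid gate
    return y + sig * shared(x)                         # blend with shared expert

rng = np.random.default_rng(0)
dim, n_exp = 8, 8
experts = [(lambda W: (lambda v: v @ W))(rng.standard_normal((dim, dim)) * 0.1)
           for _ in range(n_exp)]
shared = (lambda W: (lambda v: v @ W))(rng.standard_normal((dim, dim)) * 0.1)
out = moe_forward(rng.standard_normal((5, dim)), experts, shared,
                  rng.standard_normal((dim, n_exp)),
                  rng.standard_normal((dim, 1)))
print(out.shape)  # (5, 8)
```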
Routing¶
The gate produces expert scores via a linear projection + softmax:
```python
gates = self.gate(x)                                # (batch, seq_len, 8)
gates = mx.softmax(gates, axis=-1)
inds = mx.argpartition(-gates, kth=k - 1)[..., :k]  # top-2 indices
scores = mx.take_along_axis(gates, inds, axis=-1)   # top-2 scores
```
The top-2 indices (inds) are passed through stop_gradient to prevent routing gradient instability. Expert outputs are weighted by their softmax scores and summed.
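The same selection run on a concrete gate vector with NumPy's identically named argpartition and take_along_axis (argpartition returns the top-k in arbitrary order, which is fine because indices and scores stay paired):

```python
import numpy as np

# Gate scores for one token over 8 routed experts (already softmaxed).
gates = np.array([[0.05, 0.40, 0.10, 0.25, 0.05, 0.05, 0.05, 0.05]])
k = 2
inds = np.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]   # top-2 indices
scores = np.take_along_axis(gates, inds, axis=-1)             # paired weights
print(sorted(inds[0].tolist()))   # [1, 3] – experts with the two largest gates
```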
SwitchGLU: Expert-Routed SwiGLU¶
Each routed expert is implemented as a SwitchGLU, a SwiGLU MLP that routes tokens to per-expert weight matrices using MLX's gather_mm:
```
Input x + expert indices
         │
    ┌────┼────┐
    ▼    ▼    ▼
  gate   up  down      (SwitchLinear per expert)
    │    │    │
    ▼    ▼    │
SwiGLU: SiLU(gate) × up
         │
         ▼
     down_proj
         │
         ▼
  Expert output
```
Gather-sort optimization: When the number of tokens reaches 64 or more, tokens are sorted by expert index before gather_mm. This groups same-expert tokens contiguously, improving memory access patterns:
```python
if indices.size >= 64:
    x_sorted, idx_flat, inv_order = _gather_sort(x_exp, indices)
    # ... process sorted ...
    y = _scatter_unsort(y_flat, inv_order, shape=(B, L, K))
```
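A NumPy sketch of what such a sort/unsort pair could look like (the helper names mirror the snippet above, but this is an illustrative reimplementation, not the library code):

```python
import numpy as np

def gather_sort(x, indices):
    """Sort tokens by expert id so same-expert rows are contiguous."""
    order = np.argsort(indices, kind="stable")
    inv_order = np.argsort(order)          # undoes the sort afterwards
    return x[order], indices[order], inv_order

def scatter_unsort(y_sorted, inv_order):
    return y_sorted[inv_order]

x = np.arange(6.0).reshape(6, 1)           # 6 tokens, dim 1
idx = np.array([2, 0, 2, 1, 0, 1])         # expert assignment per token
x_s, idx_s, inv = gather_sort(x, idx)
print(idx_s)                               # [0 0 1 1 2 2] – grouped by expert
y = scatter_unsort(x_s * 10, inv)          # pretend every expert multiplies by 10
print(y.ravel())                           # back in original token order
```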
Shared Expert¶
The shared expert is a standard SwiGLU MLP (gate_proj, up_proj, down_proj) applied to all tokens unconditionally. Its output is gated by a learned sigmoid:
\[
\text{shared\_out} = \sigma(g_s(x)) \odot \text{SwiGLU}(x)
\]
This allows the model to dynamically blend shared knowledge (always available) with routed expert knowledge (specialized):
\[
y = \sum_{e \in \text{top-}2} s_e \, E_e(x) + \text{shared\_out}
\]
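Putting the pieces together as a runnable NumPy sketch (weights and names are illustrative; the real module uses MLX layers):

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def shared_expert(x, w_gate, w_up, w_down, w_sig):
    """SwiGLU MLP applied to all tokens, scaled by a learned sigmoid gate."""
    h = silu(x @ w_gate) * (x @ w_up)        # SwiGLU hidden activation
    out = h @ w_down                         # back to model dim
    g = 1.0 / (1.0 + np.exp(-(x @ w_sig)))   # sigmoid gate in (0, 1)
    return g * out

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
shared_out = shared_expert(
    x,
    rng.standard_normal((8, 16)), rng.standard_normal((8, 16)),
    rng.standard_normal((16, 8)), rng.standard_normal((8, 1)),
)
print(shared_out.shape)  # (4, 8)
```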
Why a Shared Expert?¶
| Without shared expert | With shared expert |
|---|---|
| All knowledge must be routed | General knowledge always available |
| Router errors cause information loss | Shared expert acts as safety net |
| Load balancing is critical | Less sensitive to routing quality |
| Each expert must re-learn common patterns | Shared expert handles common patterns, experts specialize |
Block Composition¶
The SWA+MoE block and SSM+MoE block share the same MoE design but differ in their first sub-layer:
AxonSWAMoEBlock (Layers 9–16)¶
```
x → RMSNorm → SWA → (+residual) → RMSNorm → MoE → (+residual) → output
               │                             │
            KV cache                      no cache
```
AxonSSMMoEBlock (Layers 17–24)¶
```
x → RMSNorm → SSM → (+residual) → RMSNorm → MoE → (+residual) → output
               │                             │
            SSM state                     no cache
```
Both blocks use pre-norm residual connections with separate RMSNorm instances for each sub-layer.
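The shared wiring can be sketched as follows (mixer and moe are stand-in callables for the attention/SSM and MoE sub-layers; rms_norm omits the learned scale for brevity):

```python
import numpy as np

def rms_norm(x, eps=1e-5):
    return x / np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)

def prenorm_block(x, mixer, moe):
    """Pre-norm residual wiring shared by both block types above."""
    x = x + mixer(rms_norm(x))   # SWA (layers 9-16) or SSM (layers 17-24)
    x = x + moe(rms_norm(x))     # shared-expert MoE sub-layer
    return x

x = np.random.default_rng(0).standard_normal((4, 8))
y = prenorm_block(x, mixer=lambda h: 0.5 * h, moe=lambda h: 0.1 * h)
print(y.shape)  # (4, 8)
```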
Sparsity and Thermal Benefits¶
With top-2 out of 8 experts, only 25% of routed expert parameters are active per token. Combined with the shared expert:
- Active parameters per token: ~1.4B (44% of 3.2B total)
- Dormant parameters: ~1.8B (56%), costing no compute and no memory bandwidth
- Thermal impact: lower chip utilization enables sustained inference/training on a fanless MacBook Air
The swiglu activation function is also JIT-compiled with @mx.compile.
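A NumPy stand-in for that activation, runnable without MLX (in the model, the MLX version of this two-line function is what carries the @mx.compile decorator):

```python
import numpy as np

# NumPy stand-in for the compiled activation (illustrative; the model applies
# @mx.compile to the MLX equivalent so its elementwise ops can be fused).
def swiglu(gate, up):
    return gate / (1.0 + np.exp(-gate)) * up   # SiLU(gate) * up

g = np.array([0.0, 1.0, -1.0])
u = np.array([2.0, 2.0, 2.0])
print(swiglu(g, u))
```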
See also¶
- Architecture Overview
- Sandwich Architecture Paper – Mathematical formulation of the three-zone design
- Axon-SSM – State space model used in layers 1–8 and 17–24
- Memory Budget – KV cache memory analysis for SWA layers
- API – Layers: SlidingWindowAttention and SharedExpertMoE Python API