Axon-SSM: Selective State Space Model
Axon-SSM is Bit-Axon's Mamba-style selective state space model. It replaces standard Transformer self-attention with linear recurrence, achieving \(O(1)\) memory per token and eliminating the KV cache entirely.
Overview
| Property | Value |
|---|---|
| State dimension (\(d_\text{state}\)) | 16 |
| Convolution kernel (\(d_\text{conv}\)) | 4 |
| Expansion ratio (\(\text{ssm\_expand}\)) | 3 |
| Intermediate dimension (\(E\)) | \(2560 \times 3 = 7680\) |
| Memory per token | \(O(1)\) – fixed state vector |
| KV cache | None |
Axon-SSM appears in 16 of 24 layers:
- Layers 1–8: Pure SSM (`AxonSSMBlock`) – the SSM's internal expansion replaces the FFN/MLP role entirely
- Layers 17–24: SSM + MoE (`AxonSSMMoEBlock`) – SSM handles recurrence, MoE adds sparse computation
Algorithm
The forward pass through Axon-SSM follows these steps:
```
Input x: (batch, seq_len, hidden_dim=2560)
             │
             ▼
      ┌── in_proj ──┐
      │   (D → 2E)  │
      └──────┬──────┘
             │ split
      ┌──────┴──────┐
      ▼             ▼
  x_branch       z_branch     (each dim E=7680)
      │             │
      ▼             │
┌─────────────┐     │
│   conv1d    │     │   causal, kernel=4, groups=E
│ (depthwise) │     │
└──────┬──────┘     │
       ▼            │
   ┌────────┐       │
   │  SiLU  │       │
   └───┬────┘       │
       ▼            │
┌──── x_proj ─────┐ │
│(E → 2·d_state+1)│ │
└───────┬─────────┘ │
  ┌─────┼─────┐     │
  ▼     ▼     ▼     │
  B     C   dt_raw  │   (B, C: d_state each; dt_raw: 1)
  │     │     │     │
  │     │     ▼     │
  │     │ ┌───────────┐
  │     │ │ dt_proj   │   (1 → E)
  │     │ │ +softplus │
  │     │ │ +clip     │
  │     │ └─────┬─────┘
  │     │       ▼   │
  │     │      dt   │   (per-channel step size, dim E)
  │     │       │   │
  ▼     ▼       ▼   │
┌─────────────────┐ │
│    SSM Scan     │ │   sequential recurrence over seq_len
│   (see below)   │ │
└────────┬────────┘ │
         ▼          │
         y          │
         │          │
         ▼          ▼
 ┌──────────────────┐
 │   y * SiLU(z)    │   gating: multiply SSM output by activated z branch
 └────────┬─────────┘
          ▼
   ┌── out_proj ───┐
   │ (E → D=2560)  │
   └───────┬───────┘
           ▼
Output: (batch, seq_len, 2560)
```
Key Projections
| Layer | Shape | Purpose |
|---|---|---|
| `in_proj` | \((D, 2E)\) | Split into \(x\) and \(z\) branches (gating) |
| `conv1d` | Depthwise, kernel=4 | Local causal context before SSM |
| `x_proj` | \((E, 2 \cdot d_\text{state} + 1)\) | Produces \(B\), \(C\), and raw \(\Delta t\) |
| `dt_proj` | \((1, E)\) | Per-channel step size with bias |
| `out_proj` | \((E, D)\) | Project back to hidden dimension |
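As a dimension sanity check, the projection chain can be traced with a NumPy sketch (zero-valued weights, purely to verify shapes; the MLX specifics and the depthwise convolution are omitted):

```python
import numpy as np

D, E, d_state = 2560, 7680, 16
batch, seq = 1, 5

x = np.zeros((batch, seq, D), dtype=np.float32)

# in_proj: D -> 2E, then split into the x and z branches
xz = x @ np.zeros((D, 2 * E), dtype=np.float32)
x_branch, z_branch = np.split(xz, 2, axis=-1)        # each (batch, seq, E)

# x_proj: E -> 2*d_state + 1, split into B, C and the raw step size
bcd = x_branch @ np.zeros((E, 2 * d_state + 1), dtype=np.float32)
B, C, dt_raw = np.split(bcd, [d_state, 2 * d_state], axis=-1)

# dt_proj: broadcast the scalar raw step size to one value per channel
dt = dt_raw @ np.zeros((1, E), dtype=np.float32)     # (batch, seq, E)

print(x_branch.shape, B.shape, C.shape, dt.shape)
```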
SSM Recurrence
The core scan computes a discretized linear recurrence at each timestep \(t\):

\[
h_t = \underbrace{\exp(\Delta t_t\, A)}_{dA} \odot h_{t-1} + \underbrace{(\Delta t_t\, B_t)}_{dB} \odot x_t
\]

\[
y_t = C_t \cdot h_t + D \odot x_t
\]

Where:
| Symbol | Shape | Description |
|---|---|---|
| \(h_t\) | \((B_\text{batch}, E, d_\text{state})\) | Hidden state at time \(t\) |
| \(A\) | \((E, d_\text{state})\) | Diagonal state matrix (learnable, stored as \(\log\)) |
| \(B_t\) | \((B_\text{batch}, d_\text{state})\) | Input-selective matrix at time \(t\) |
| \(C_t\) | \((B_\text{batch}, d_\text{state})\) | Output-selective matrix at time \(t\) |
| \(\Delta t_t\) | \((B_\text{batch}, E)\) | Per-channel step size at time \(t\) |
| \(D\) | \((E,)\) | Skip connection (initialized to ones) |
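The recurrence can be written as a NumPy reference scan using the shapes from the table (an illustrative sketch, not the fused MLX kernel; the helper name and demo sizes are made up for the example):

```python
import numpy as np

def selective_scan(x, dt, A, B, C, D_skip):
    """Reference SSM scan. Shapes: x, dt: (batch, seq, E);
    A: (E, d_state); B, C: (batch, seq, d_state); D_skip: (E,)."""
    batch, seq, E = x.shape
    h = np.zeros((batch, E, A.shape[1]), dtype=x.dtype)
    ys = []
    for t in range(seq):
        dA = np.exp(dt[:, t, :, None] * A)           # discretized state matrix
        dB = dt[:, t, :, None] * B[:, t, None, :]    # discretized input matrix
        h = dA * h + dB * x[:, t, :, None]           # state update (the FMA step)
        y = (h * C[:, t, None, :]).sum(-1)           # read out state: (batch, E)
        ys.append(y + D_skip * x[:, t])              # skip connection
    return np.stack(ys, axis=1)                      # (batch, seq, E)

# tiny smoke run with illustrative sizes (not the model's real dimensions)
rng = np.random.default_rng(0)
b, s, E, ds = 1, 3, 4, 2
A = -np.tile(np.arange(1.0, ds + 1), (E, 1))         # decaying diagonals -1..-ds
y = selective_scan(rng.standard_normal((b, s, E)),
                   np.full((b, s, E), 0.01),
                   A,
                   rng.standard_normal((b, s, ds)),
                   rng.standard_normal((b, s, ds)),
                   np.ones(E))
```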
Discretization
The step size \(\Delta t\) is computed through a softplus projection with clamping:

\[
\Delta t = \mathrm{clip}\big(\mathrm{softplus}(W_{dt}\,\Delta t_\text{raw} + b_{dt}),\ \Delta t_\text{min},\ \Delta t_\text{max}\big)
\]

This keeps \(\Delta t\) in a numerically stable range while remaining input-dependent (selective).
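In NumPy terms this is a one-liner (a sketch; `dt_min`/`dt_max` here are placeholder bounds, not the model's configured values):

```python
import numpy as np

def softplus(x):
    # numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)

def compute_dt(dt_raw, dt_bias, dt_min=1e-3, dt_max=1e-1):
    """Bias + softplus + clamp. dt_min/dt_max are illustrative
    placeholder bounds, not the model's actual config values."""
    return np.clip(softplus(dt_raw + dt_bias), dt_min, dt_max)

print(compute_dt(np.array([-5.0, 0.0, 5.0]), 0.0))
```

Large raw values saturate at the upper clamp, while very negative values are floored at the lower clamp, so every channel's step size stays strictly positive.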
State Initialization
The \(A\) matrix is initialized as a repeated diagonal:

\[
A_{\log}[e, j] = \log(j), \qquad j = 1, \dots, d_\text{state},\quad e = 1, \dots, E
\]

At runtime: \(A = -\exp(A_{\log})\), producing diagonals from \(-1\) to \(-d_\text{state}\). Because \(A\) is negative, the discretized factor \(\exp(\Delta t\, A)\) is below 1, so the hidden state decays stably over time.
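A sketch of this construction in NumPy (consistent with the description above; the actual source code may arrange it differently):

```python
import numpy as np

E, d_state = 7680, 16

# A_log[e, :] = log([1, 2, ..., d_state]), identical across all E channels
A_log = np.tile(np.log(np.arange(1, d_state + 1)), (E, 1))   # (E, d_state)

# runtime: negate the exponential to recover decaying diagonals -1 .. -d_state
A = -np.exp(A_log)
print(A[0, :4])
```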
Memory Properties
Constant Memory Per Token
Unlike standard attention where the KV cache grows as \(O(n)\) with sequence length, the SSM maintains a fixed-size state:
| Component | Size |
|---|---|
| SSM state | \((B_\text{batch},\ E=7680,\ d_\text{state}=16)\) |
| Conv cache | \((B_\text{batch},\ K{-}1=3,\ E=7680)\) |
| Total per layer | ~1.5 MB (FP16, batch=1) |
| Total 16 SSM layers | ~24 MB |
Compare this with a full KV cache for 16 attention layers at 64K context, which would require several GB.
No KV Cache
SSM layers return `[conv_cache, ssm_state]` as their cache – small, fixed-size tensors created on the first forward pass. The model's `_create_caches()` method returns `None` for all SSM layers and `KVCache` objects only for the 8 SWA layers (9–16).
JIT-Compiled Kernels
Two leaf functions are decorated with `@mx.compile` for fused Metal kernel generation (following the Jamba pattern):
_ssm_fma

```python
@mx.compile
def _ssm_fma(a: mx.array, b: mx.array, c: mx.array) -> mx.array:
    return a * b + c  # dA * h + dB * x_t (fused multiply-add)
```
This fuses the state update \(h_t = dA \cdot h_{t-1} + dB \cdot x_t\) into a single kernel, avoiding intermediate tensor allocation.
_compute_dt

```python
@mx.compile
def _compute_dt(dt: mx.array, dt_bias: mx.array, lo: float, hi: float) -> mx.array:
    return mx.clip(nn.softplus(dt + dt_bias), lo, hi)
```
Fuses the bias addition, softplus activation, and clamping into one kernel.
Autoregressive Decoding
During incremental (token-by-token) generation, the cache mechanism works as follows:
- First call (prefill, `cache=None`): Process the full prompt, initialize `ssm_state` to zeros, build `conv_cache` from the last \(K-1\) positions.
- Subsequent calls (decode, `cache=[conv_cache, ssm_state]`): Concatenate `conv_cache` with the new single token, run one step of the scan using the previous `ssm_state`, return updated caches.
The scan loop iterates over `seq_len` positions: during prefill this is the full prompt length; during decode it is exactly 1.
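The convolution half of a decode step can be sketched in NumPy as follows (hypothetical helper; the real layer is the depthwise `conv1d`, and the `(E, K)` weight layout here is an assumption made for readability):

```python
import numpy as np

def decode_step_conv(conv_cache, x_t, conv_w, conv_b):
    """One causal depthwise-conv step during decode.
    conv_cache: (batch, K-1, E) previous inputs; x_t: (batch, E);
    conv_w: (E, K) depthwise taps; conv_b: (E,)."""
    window = np.concatenate([conv_cache, x_t[:, None, :]], axis=1)  # (batch, K, E)
    y_t = (window * conv_w.T[None]).sum(axis=1) + conv_b            # (batch, E)
    new_cache = window[:, 1:, :]     # roll the cache: drop the oldest position
    return y_t, new_cache

# smoke run: empty history, all-ones taps -> output equals the new token
batch, K, E = 1, 4, 3
cache = np.zeros((batch, K - 1, E))
y_t, cache = decode_step_conv(cache, np.ones((batch, E)),
                              np.ones((E, K)), np.zeros(E))
```

The `ssm_state` half is simpler still: one iteration of the scan loop updates the `(batch, E, d_state)` state in place, so decode cost per token is independent of how many tokens came before.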
Parameters
Per SSM layer parameter count (with \(D=2560\), \(E=7680\), \(d_\text{state}=16\)):
| Parameter | Shape | Count |
|---|---|---|
| `in_proj.weight` | \((2E, D)\) | 39.3M |
| `conv1d.weight` | \((E, 1, 4)\) | 30.7K |
| `conv1d.bias` | \((E,)\) | 7.7K |
| `x_proj.weight` | \((33, E)\) | 253.4K |
| `dt_proj.weight` | \((E, 1)\) | 7.7K |
| `dt_proj.bias` | \((E,)\) | 7.7K |
| `A_log` | \((E, 16)\) | 122.9K |
| `D` | \((E,)\) | 7.7K |
| `out_proj.weight` | \((D, E)\) | 19.7M |
| **Total per SSM layer** | | **~59.4M** |
With 16 SSM-bearing layers (8 pure + 8 with MoE), the SSM accounts for roughly 950M parameters of the 3.2B total.
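The per-layer count can be reproduced directly from the shapes (the rounded table entries drift slightly from the exact sum):

```python
# parameter arithmetic with D=2560, E=7680, d_state=16, conv kernel K=4
D, E, d_state, K = 2560, 7680, 16, 4

params = {
    "in_proj.weight":  2 * E * D,
    "conv1d.weight":   E * 1 * K,
    "conv1d.bias":     E,
    "x_proj.weight":   (2 * d_state + 1) * E,   # the 33-row projection
    "dt_proj.weight":  E * 1,
    "dt_proj.bias":    E,
    "A_log":           E * d_state,
    "D":               E,
    "out_proj.weight": D * E,
}

total = sum(params.values())
print(f"{total / 1e6:.1f}M per layer, {16 * total / 1e6:.0f}M across 16 layers")
```

As with most large projections, `in_proj` and `out_proj` dominate; everything SSM-specific (`A_log`, `x_proj`, `dt_proj`, the conv) is under half a percent of the layer.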
See also

- Architecture Overview
- Axon-SSM Paper – Mathematical foundations and selective scan theory
- SWA + MoE – Sliding window attention and sparse experts used alongside SSM
- API – Layers – `AxonSSM` and `AxonSSMBlock` Python API