# Architecture Overview

Bit-Axon 3.2B: A 24-layer hybrid language model combining Mamba-style state space models, sliding window attention, and shared-expert mixture-of-experts, built entirely for Apple Silicon.

## Design Philosophy: "No GPU, No Cloud"
Bit-Axon is designed from the ground up to run the full training-inference-deployment cycle on a fanless MacBook Air M4 with 16 GB unified memory. Three architectural pillars make this possible:
| Pillar | Technique | Effect |
|---|---|---|
| Linear | Axon-SSM (Mamba-style SSM) | \(O(1)\) memory per token, no KV cache |
| Sparse | Shared-Expert MoE | Only ~1.4B of 3.2B params active per token |
| Quantized | 4-bit weights + TurboQuant KV cache | ~1.76 GB weight footprint |
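The linear pillar's \(O(1)\)-memory claim comes down to a recurrence that carries only a fixed-size state. A minimal scalar sketch of that idea (the real Axon-SSM uses learned, input-dependent parameters over a `d_state`-wide vector; `ssm_scan` and its constants here are hypothetical stand-ins):

```python
def ssm_scan(xs, a=0.9, b=0.1):
    """Run h_t = a*h_{t-1} + b*x_t; only h (constant size) is carried."""
    h = 0.0
    ys = []
    for x in xs:            # one pass over the sequence, no cache of past tokens
        h = a * h + b * x   # state update: fixed memory regardless of position
        ys.append(h)        # per-token output derived from the running state
    return ys

print(ssm_scan([1.0, 1.0, 1.0]))  # roughly [0.1, 0.19, 0.271]
```

However long the context grows, the only thing kept between tokens is `h`, which is why the SSM zones need no KV cache.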
## 24-Layer Sandwich Structure
The 24 layers are divided into three functional zones. SSM layers maintain constant memory regardless of context length; only the middle SWA zone produces a KV cache.
```
Bit-Axon 3.2B – 24-Layer Sandwich
┌────────────────────────────────────────────────────────────────┐
│ Input: embed_tokens (vocab 32K) → input_proj (2048→2560)       │
├────────────────────────────────────────────────────────────────┤
│                              ↓                                 │
│ Zone 1: Foundation – Layers 1–8 (Pure Axon-SSM)                │
│ ┌──────────────────────────────────────────────────────────┐   │
│ │ AxonSSMBlock: RMSNorm → AxonSSM (no MLP)                 │   │
│ │  • Context absorption via linear recurrence              │   │
│ │  • No KV cache → O(1) memory per token                   │   │
│ └──────────────────────────────────────────────────────────┘   │
│                          × 8 layers                            │
│                              ↓                                 │
│ Zone 2: Deep Reasoning – Layers 9–16 (SWA + MoE)               │
│ ┌──────────────────────────────────────────────────────────┐   │
│ │ AxonSWAMoEBlock:                                         │   │
│ │  RMSNorm → SWA (window=4096) + residual                  │   │
│ │  RMSNorm → SharedExpertMoE (8 experts, top-2) + residual │   │
│ │  • Sliding window attention for local reasoning          │   │
│ │  • KV cache required (layers 9–16 only)                  │   │
│ └──────────────────────────────────────────────────────────┘   │
│                          × 8 layers                            │
│                              ↓                                 │
│ Zone 3: Output Synthesis – Layers 17–24 (SSM + MoE)            │
│ ┌──────────────────────────────────────────────────────────┐   │
│ │ AxonSSMMoEBlock:                                         │   │
│ │  RMSNorm → AxonSSM + residual                            │   │
│ │  RMSNorm → SharedExpertMoE (8 experts, top-2) + residual │   │
│ │  • Linear recurrence + sparse experts                    │   │
│ │  • No KV cache → thermally efficient output generation   │   │
│ └──────────────────────────────────────────────────────────┘   │
│                          × 8 layers                            │
│                              ↓                                 │
├────────────────────────────────────────────────────────────────┤
│ Output: output_proj (2560→2048) → lm_head (2048→32000)         │
└────────────────────────────────────────────────────────────────┘
```
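The zone boundaries can be written down as a layer-index dispatch. Block class names follow the diagram; the function itself is an illustrative sketch, not the model's actual construction code:

```python
def block_type(layer_idx: int) -> str:
    """Map a 1-based layer index to the block type of its zone."""
    if 1 <= layer_idx <= 8:        # Zone 1: pure SSM, no KV cache
        return "AxonSSMBlock"
    if 9 <= layer_idx <= 16:       # Zone 2: SWA + MoE, the only KV-caching zone
        return "AxonSWAMoEBlock"
    if 17 <= layer_idx <= 24:      # Zone 3: SSM + MoE
        return "AxonSSMMoEBlock"
    raise ValueError("layer index out of range")

layers = [block_type(i) for i in range(1, 25)]
# Only the 8 middle layers ever allocate a KV cache.
assert sum(t == "AxonSWAMoEBlock" for t in layers) == 8
```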
## Why This Layout?
| Question | Answer |
|---|---|
| Why SSM for layers 1–8? | Raw context needs absorption, not reasoning. SSM gives \(O(1)\) memory and eliminates the KV cache for 1/3 of the model. |
| Why SWA only for layers 9–16? | Deep reasoning benefits from local attention. Restricting SWA to 8 layers caps KV cache memory to a manageable range. |
| Why drop attention in layers 17–24? | Output generation is autoregressive. SSM + MoE produces tokens with minimal thermal load, which is critical on a fanless device. |
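The KV-cache cap from restricting SWA to 8 layers can be checked with back-of-envelope arithmetic using the configuration table's values. The per-element byte sizes (FP16 vs. 4-bit) are illustrative assumptions; the actual TurboQuant layout adds scale/offset overhead:

```python
num_swa_layers = 8         # only Zone 2 produces a KV cache
window = 4096              # swa_window_size caps cached positions per layer
num_heads, head_dim = 32, 80

# 2 = one K tensor and one V tensor per layer
elems = num_swa_layers * 2 * window * num_heads * head_dim
fp16_mib = elems * 2 / 2**20    # 2 bytes per element in FP16
int4_mib = elems * 0.5 / 2**20  # 0.5 bytes per element at 4-bit

print(f"KV cache bound: {fp16_mib:.0f} MiB (FP16) vs {int4_mib:.0f} MiB (4-bit)")
```

Because the window slides, this bound holds no matter how long the context grows, which is what makes 64K-token contexts feasible in 16 GB of unified memory.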
## Model Configuration

| Parameter | Value | Notes |
|---|---|---|
| `vocab_size` | 32,000 | Truncated from Qwen's 151K |
| `hidden_dim` | 2,560 | Model width (\(d_\text{model}\)) |
| `num_layers` | 24 | Sandwich: 8 + 8 + 8 |
| `num_heads` | 32 | SWA attention heads |
| `head_dim` | 80 | \(2560 / 32\) |
| `d_source_model` | 2,048 | Qwen2.5-3B bridge dimension |
| `ssm_d_state` | 16 | SSM state vector size |
| `ssm_d_conv` | 4 | Causal conv1d kernel size |
| `ssm_expand` | 3 | SSM intermediate dim = \(2560 \times 3 = 7680\) |
| `swa_window_size` | 4,096 | Sliding window span |
| `moe_num_experts` | 8 | Total routed experts |
| `moe_top_k` | 2 | Active experts per token |
| `moe_intermediate_dim` | 4,096 | Expert FFN dimension |
| `moe_shared_expert` | `true` | Shared expert always active |
| `max_seq_len` | 65,536 | Maximum context (64K) |
| `weight_tying` | `true` | `embed_tokens.weight` = `lm_head.weight` |
| `rms_norm_eps` | 1e-6 | RMSNorm epsilon |
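The table can be mirrored as a config object, which also makes the derived relationships explicit. Field names follow the table; the dataclass itself is a hypothetical sketch, not the project's actual config class:

```python
from dataclasses import dataclass

@dataclass
class BitAxonConfig:
    vocab_size: int = 32_000
    hidden_dim: int = 2_560
    num_layers: int = 24
    num_heads: int = 32
    head_dim: int = 80
    d_source_model: int = 2_048
    ssm_d_state: int = 16
    ssm_d_conv: int = 4
    ssm_expand: int = 3
    swa_window_size: int = 4_096
    moe_num_experts: int = 8
    moe_top_k: int = 2
    moe_intermediate_dim: int = 4_096
    moe_shared_expert: bool = True
    max_seq_len: int = 65_536
    weight_tying: bool = True
    rms_norm_eps: float = 1e-6

cfg = BitAxonConfig()
assert cfg.head_dim == cfg.hidden_dim // cfg.num_heads     # 2560 / 32 = 80
assert cfg.hidden_dim * cfg.ssm_expand == 7_680            # SSM intermediate dim
```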
## Key Design Decisions

### Dimension Bridge (\(d_\text{source} = 2048\))
Weights are ported from Qwen2.5-3B, which uses `hidden_size=2048`. Bit-Axon operates at `hidden_dim=2560` for wider representations. Two projection layers handle the dimension mismatch:
```
embed_tokens(vocab=32000, dim=2048)
        ↓
input_proj(2048 → 2560)    ← projects into Bit-Axon's wider space
        ↓
[24 sandwich layers at dim=2560]
        ↓
output_proj(2560 → 2048)   ← projects back to source dimension
        ↓
lm_head(2048 → 32000)
```
The first 24 of Qwen's 36 layers are mapped 1:1; layers 24–35 are discarded. SSM parameters are randomly initialized. MoE expert weights are seeded from Qwen's MLP: the shared expert is a direct copy, expert 0 is an exact copy, and experts 1–7 are perturbed copies.
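A sketch of that expert initialization, with plain Python lists standing in for weight tensors; the perturbation scale and function name are hypothetical choices for illustration:

```python
import random

random.seed(0)

def init_experts(qwen_mlp, num_experts=8, noise=0.01):
    """Seed MoE experts from a source MLP: shared + expert 0 exact, 1..7 perturbed."""
    shared = list(qwen_mlp)                 # shared expert: exact copy
    experts = [list(qwen_mlp)]              # expert 0: exact copy
    for _ in range(1, num_experts):         # experts 1-7: small Gaussian noise
        experts.append([w + random.gauss(0.0, noise) for w in qwen_mlp])
    return shared, experts

shared, experts = init_experts([0.5, -0.25, 1.0])
assert experts[0] == [0.5, -0.25, 1.0]      # expert 0 untouched
assert experts[1] != experts[0]             # perturbation breaks symmetry
```

The perturbation is what lets initially identical experts diverge during fine-tuning instead of collapsing to the same function.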
### Weight Tying

`embed_tokens.weight` and `lm_head.weight` share the same parameter tensor. This eliminates ~125 MiB of duplicate storage (\(2048 \times 32000 \times 2\) bytes in FP16).
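Weight tying in miniature: one object serves as both the embedding matrix and the LM head, so it is stored once. The toy class below is illustrative, and the arithmetic reproduces the saving stated above:

```python
class TinyModel:
    def __init__(self, vocab, dim):
        self.embed_tokens = [[0.0] * dim for _ in range(vocab)]
        self.lm_head = self.embed_tokens      # tied: same object, no second copy

m = TinyModel(vocab=4, dim=2)
m.embed_tokens[0][0] = 1.0
assert m.lm_head[0][0] == 1.0                 # one storage, two roles

saved_mib = 2048 * 32000 * 2 / 2**20          # duplicate FP16 bytes avoided
print(f"~{saved_mib:.0f} MiB saved")          # ~125 MiB
```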
## MLX Integration

The entire model is built on Apple's MLX framework, not PyTorch:

- **JIT compilation**: Leaf functions (`_ssm_fma`, `_compute_dt`, `swiglu`) are decorated with `@mx.compile` for fused Metal kernels
- **Unified memory zero-copy**: Quantized weights are loaded once and accessed by both CPU and GPU without copies
- **Metal-optimized quantization**: `nn.QuantizedLinear` uses fused Metal kernels for 4-bit matmul

**MLX Compilation Constraints**

- Model-level `mx.compile` is used (layer-level compilation fails due to module reference issues)
- `shapeless=True` is broken for matmul in MLX ≤ 0.31; use shape-dependent compilation instead
- NumPy interop in MoE routing breaks `mx.compile` tracing; pure MLX dispatch is required
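The "pure MLX dispatch" constraint means the routing math itself must stay traceable, with no NumPy round-trips. This plain-Python sketch shows just the routing logic (softmax, top-2 selection, renormalization); in the model the same steps would be expressed as MLX ops so `mx.compile` can trace them. Renormalizing the top-k weights is a common MoE convention assumed here, not something this page specifies:

```python
import math

def route(logits, top_k=2):
    """Return (expert_index, weight) pairs for the top_k experts of one token."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]               # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    top = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
    norm = sum(probs[i] for i in top)                      # renormalize over top-k
    return [(i, probs[i] / norm) for i in top]

picks = route([0.1, 2.0, -1.0, 0.5, 0.0, 1.5, 0.2, -0.3])  # 8 expert logits
assert [i for i, _ in picks] == [1, 5]                     # two experts chosen
assert abs(sum(w for _, w in picks) - 1.0) < 1e-9          # weights sum to 1
```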
## Sub-Pages
| Page | Content |
|---|---|
| Axon-SSM | Mamba-style selective state space model: algorithm, math, and implementation |
| SWA + MoE | Sliding window attention and shared-expert mixture-of-experts |
| Memory Budget | Memory breakdown, quantization strategy, and thermal management |
## See also

- Papers – Theoretical foundations and mathematical formulations behind each component
- Training Guide – Thermal-aware QLoRA fine-tuning workflow
- Quantization Guide – NF4 quantization details and memory impact
- Weight Porting Guide – How Qwen2.5-3B weights are mapped to Bit-Axon
- API Reference – Layers – Python API for all layer implementations