Model¶
bit_axon.model.BitAxonModel ¶
BitAxonModel(config: BitAxonConfig)
Bases: Module
3.2B hybrid language model with SSM, SWA, and MoE layers.
24-layer sandwich architecture:
- Layers 0-7: Pure SSM (linear recurrence, no KV cache)
- Layers 8-15: SWA + MoE (sliding window attention + sparse experts)
- Layers 16-23: SSM + MoE (linear recurrence + sparse experts)
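The sandwich schedule above can be sketched as a simple index-to-block-type mapping. This is a hypothetical helper for illustration, not part of the `bit_axon` package:

```python
def layer_kind(i: int) -> str:
    """Map a layer index (0-23) to its block type in the sandwich schedule."""
    if not 0 <= i < 24:
        raise ValueError("layer index out of range")
    if i < 8:
        return "ssm"       # pure SSM: linear recurrence, no KV cache
    if i < 16:
        return "swa+moe"   # sliding window attention + sparse experts
    return "ssm+moe"       # linear recurrence + sparse experts

schedule = [layer_kind(i) for i in range(24)]
```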
Attributes:
| Name | Description |
|---|---|
| `config` | Model configuration. |
| `embed_tokens` | Token embedding table. |
| `input_proj` | Projects from source model dimension to hidden dim. |
| `output_proj` | Projects from hidden dim back to source model dimension. |
| `lm_head` | Output projection to vocabulary logits. |
Initialize the BitAxon model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `BitAxonConfig` | `BitAxonConfig` with architecture hyperparameters. | required |
Functions¶
__call__ ¶
Forward pass through all layers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_ids` | `array` | Token indices of shape `(batch, seq_len)`. | required |
| `cache` | | Optional list of per-layer caches from a previous call. | `None` |
Returns:
| Type | Description |
|---|---|
| | Tuple of `(logits, new_caches)`. Logits have shape `(batch, seq_len, vocab_size)`; `new_caches` is a list of updated per-layer caches. |
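A minimal sketch of how the `(logits, new_caches)` contract supports incremental decoding: the caches returned from a prompt pass are fed back into the next call. `DummyLayer` and `forward` below are stand-ins written for this sketch, not the real `BitAxonModel` blocks:

```python
class DummyLayer:
    """Stand-in for a BitAxon layer: passes input through and counts cached tokens."""
    def __call__(self, x, cache=None):
        seen = 0 if cache is None else cache
        return x, seen + len(x)  # new cache accumulates tokens seen so far

def forward(layers, input_ids, caches=None):
    """Thread one cache per layer through the stack, mirroring __call__'s contract."""
    caches = caches if caches is not None else [None] * len(layers)
    new_caches = []
    h = input_ids
    for layer, c in zip(layers, caches):
        h, c = layer(h, c)
        new_caches.append(c)
    return h, new_caches  # (logits stand-in, updated per-layer caches)

layers = [DummyLayer() for _ in range(3)]
_, caches = forward(layers, [1, 2, 3])    # prefill the prompt
_, caches = forward(layers, [4], caches)  # one incremental decode step
```

Each call returns fresh caches rather than mutating the old ones, so a caller can branch generation by reusing an earlier cache list.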