Layers¶
bit_axon.layers ¶
Classes¶
AxonSSM ¶
AxonSSM(config: BitAxonConfig)
Bases: Module
Mamba-style state space model layer with causal convolution.
Implements selective SSM with hardware-aware scan, causal conv1d prefix, and a gating branch (SiLU). The SSM expansion replaces the traditional FFN/MLP role in the block.
Attributes:
| Name | Type | Description |
|---|---|---|
| in_proj | | Projects input to 2x intermediate dim (x and z branches). |
| conv1d | | Depthwise causal 1D convolution. |
| x_proj | | Projects conv output to B, C, dt parameters. |
| dt_proj | | Projects raw dt to per-channel step sizes. |
| out_proj | | Projects SSM output back to hidden dim. |
| A_log | | Log of the diagonal SSM state matrix (learnable). |
| D | | Skip connection parameter per channel. |
Initialize the AxonSSM layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config | BitAxonConfig | BitAxonConfig with ssm_intermediate_dim, ssm_d_state, ssm_d_conv, and hidden_dim settings. | required |
Attributes¶
Functions¶
__call__ ¶
Run the SSM forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | | Input tensor of shape (batch, seq_len, hidden_dim). | required |
| cache | | Optional [conv_cache, ssm_state] from a previous step. | None |
Returns:
| Type | Description |
|---|---|
| | Tuple of (output, new_cache). Output has shape (batch, seq_len, hidden_dim); new_cache is [updated_conv_cache, updated_ssm_state] for autoregressive decoding. |
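The selective scan behind this layer can be sketched in plain NumPy. This is an illustrative version of the standard diagonal Mamba recurrence, not bit_axon's hardware-aware kernel; the function name, argument layout, and shapes below are assumptions for clarity.

```python
import numpy as np

def ssm_scan(x, dt, A, B, C, D, h0=None):
    """Minimal diagonal selective-SSM scan (illustrative sketch).

    x:  (seq_len, d_inner)  input sequence
    dt: (seq_len, d_inner)  per-channel step sizes (from dt_proj)
    A:  (d_inner, d_state)  diagonal state matrix (typically negative)
    B:  (seq_len, d_state)  input-dependent projection (selective)
    C:  (seq_len, d_state)  input-dependent readout (selective)
    D:  (d_inner,)          per-channel skip connection
    """
    seq_len, d_inner = x.shape
    d_state = A.shape[1]
    h = np.zeros((d_inner, d_state)) if h0 is None else h0
    ys = []
    for t in range(seq_len):
        dA = np.exp(dt[t][:, None] * A)             # discretize A: (d_inner, d_state)
        dB = dt[t][:, None] * B[t][None, :]         # discretize B: (d_inner, d_state)
        h = dA * h + dB * x[t][:, None]             # linear state update
        y = (h * C[t][None, :]).sum(-1) + D * x[t]  # readout plus skip term
        ys.append(y)
    # Final h is the ssm_state carried in the cache for autoregressive decoding.
    return np.stack(ys), h
```

Returning the final state alongside the outputs mirrors the `[conv_cache, ssm_state]` cache described above: decoding one token at a time only needs the previous state, not the full sequence.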
AxonSSMBlock ¶
AxonSSMBlock(config: BitAxonConfig)
Bases: Module
Pure SSM block. No MLP — SSM's internal expansion serves the FFN role.
Initialize the SSM block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config | BitAxonConfig | BitAxonConfig with hidden_dim and rms_norm_eps settings. | required |
Attributes¶
Functions¶
__call__ ¶
Forward pass with residual connection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | | Input tensor of shape (batch, seq_len, hidden_dim). | required |
| cache | | Optional SSM cache from a previous step. | None |
Returns:
| Type | Description |
|---|---|
| | Tuple of (output, cache). Output has shape (batch, seq_len, hidden_dim). |
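The block's forward pass with residual connection can be sketched as follows; `norm` and `ssm` are hypothetical callables standing in for the block's normalization and AxonSSM layers, with the cache threaded through unchanged.

```python
import numpy as np

def ssm_block_forward(x, norm, ssm, cache=None):
    """Pre-norm residual sketch: out = x + ssm(norm(x)).

    ssm(h, cache) -> (y, new_cache); norm is e.g. an RMSNorm callable.
    Names here are illustrative, not bit_axon's internals.
    """
    y, new_cache = ssm(norm(x), cache)
    return x + y, new_cache
```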
AxonSWAMoEBlock ¶
AxonSWAMoEBlock(config: BitAxonConfig)
Bases: Module
SWA + MoE block with sliding window attention and sparse experts.
Initialize the SWA + MoE block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config | BitAxonConfig | BitAxonConfig with attention and MoE hyperparameters. | required |
Attributes¶
attention instance-attribute ¶
attention = SlidingWindowAttention(hidden_dim, num_heads, swa_window_size)
moe instance-attribute ¶
moe = SharedExpertMoE(hidden_dim, moe_intermediate_dim, moe_num_experts, moe_top_k)
Functions¶
__call__ ¶
Forward pass: attention with residual, then MoE with residual.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | | Input tensor of shape (batch, seq_len, hidden_dim). | required |
| cache | | Optional KVCache for autoregressive decoding. | None |
Returns:
| Type | Description |
|---|---|
| | Tuple of (output, cache). Output has shape (batch, seq_len, hidden_dim). |
AxonSSMMoEBlock ¶
AxonSSMMoEBlock(config: BitAxonConfig)
Bases: Module
SSM + MoE block with linear recurrence and sparse experts.
Initialize the SSM + MoE block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config | BitAxonConfig | BitAxonConfig with SSM and MoE hyperparameters. | required |
Attributes¶
moe instance-attribute ¶
moe = SharedExpertMoE(hidden_dim, moe_intermediate_dim, moe_num_experts, moe_top_k)
Functions¶
__call__ ¶
Forward pass: SSM with residual, then MoE with residual.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | | Input tensor of shape (batch, seq_len, hidden_dim). | required |
| cache | | Optional SSM cache from a previous step. | None |
Returns:
| Type | Description |
|---|---|
| | Tuple of (output, ssm_cache). Output has shape (batch, seq_len, hidden_dim). |
SharedExpertMoE ¶
SharedExpertMoE(dim: int, intermediate_dim: int, num_experts: int = 8, top_k: int = 2, bias: bool = False)
Bases: Module
Mixture of experts with top-k routing and a gated shared expert.
Routes each token to the top-k experts out of num_experts via softmax gating, while a shared expert processes all tokens. The shared expert output is gated by a learned sigmoid to allow dynamic blending.
Attributes:
| Name | Type | Description |
|---|---|---|
| gate | | Router producing per-expert softmax scores. |
| switch_mlp | | SwitchGLU implementing per-expert projections. |
| shared_expert | | MLP applied to all tokens. |
| shared_expert_gate | | Learned gate controlling shared expert contribution. |
Initialize SharedExpertMoE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim | int | Input and output dimension. | required |
intermediate_dim | int | Hidden dimension for each expert's SwiGLU. | required |
num_experts | int | Total number of routed experts. | 8 |
top_k | int | Number of experts activated per token. | 2 |
bias | bool | Whether to include bias in projections. | False |
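The routing scheme described above can be sketched in NumPy: softmax gating, top-k selection, and a sigmoid-gated shared expert. This is an illustrative dense-loop version, not the SwitchGLU-based implementation, and all names below are assumptions.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def moe_forward(x, gate_w, experts, shared_expert, shared_gate_w, top_k=2):
    """Top-k routed experts plus a sigmoid-gated shared expert (sketch).

    x: (tokens, dim); gate_w: (dim, num_experts); experts: list of callables;
    shared_expert: callable; shared_gate_w: (dim, 1).
    """
    scores = softmax(x @ gate_w)                    # per-expert routing weights
    topk = np.argsort(scores, axis=-1)[:, -top_k:]  # indices of the top-k experts
    out = np.zeros_like(x)
    for t in range(x.shape[0]):                     # real kernels batch this loop
        for e in topk[t]:
            out[t] += scores[t, e] * experts[e](x[t])
    g = 1.0 / (1.0 + np.exp(-(x @ shared_gate_w)))  # learned sigmoid gate
    return out + g * shared_expert(x)               # blend in the shared expert
```

With top_k equal to num_experts and identity experts, the routed output reduces to the input (the softmax weights sum to one), which makes the gating easy to sanity-check.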
SlidingWindowAttention ¶
SlidingWindowAttention(hidden_dim: int, num_heads: int, window_size: int)
Bases: Module
Multi-head attention with a sliding window mask.
Restricts each token to attend only to tokens within a fixed window size, reducing attention from O(n²) to O(n * window_size).
Attributes:
| Name | Type | Description |
|---|---|---|
| q_proj | | Query projection. |
| k_proj | | Key projection. |
| v_proj | | Value projection. |
| o_proj | | Output projection. |
Initialize sliding window attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
hidden_dim | int | Model hidden dimension. | required |
num_heads | int | Number of attention heads. | required |
window_size | int | Maximum attention distance per side. | required |
Attributes¶
Functions¶
__call__ ¶
Forward pass for sliding window attention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | | Input tensor of shape (batch, seq_len, hidden_dim). | required |
| mask | | Optional attention mask to add to scores. If None, a sliding window causal mask is generated automatically. | None |
| cache | | Optional KVCache for autoregressive decoding. | None |
Returns:
| Type | Description |
|---|---|
| | Tuple of (output, cache). Output has shape (batch, seq_len, hidden_dim); cache is the updated KVCache or None. |
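The automatically generated mask could look like the sketch below: an additive mask that is 0 where a query may attend and -inf elsewhere, restricting each position to a causal window behind it. The exact off-by-one convention for window_size is an assumption here.

```python
import numpy as np

def sliding_window_mask(seq_len, window_size):
    """Additive sliding window causal mask (illustrative sketch).

    Query position i may attend key position j when
    i - window_size <= j <= i (causal, bounded lookback).
    """
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    allowed = (j <= i) & (j >= i - window_size)
    # 0 keeps the score; -inf zeroes the weight after softmax.
    return np.where(allowed, 0.0, -np.inf)
```

Adding this mask to the raw attention scores before the softmax is what turns dense O(n²) attention into the O(n * window_size) pattern described above.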
RMSNorm ¶
Bases: Module
RMS normalization using MLX's fast kernel.
Unlike LayerNorm, RMSNorm does not center the input (no mean subtraction); it normalizes by the root mean square of the input and applies a learned scale.
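A reference sketch of the computation in plain NumPy (not MLX's fast kernel):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-5):
    """RMSNorm sketch: no mean subtraction, just RMS scaling and a learned gain."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return (x / rms) * weight
```

After normalization (with a unit weight), each token's features have root mean square approximately 1, regardless of their original scale or mean.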