Layers

bit_axon.layers

Classes

AxonSSM
AxonSSM(config: BitAxonConfig)

Bases: Module

Mamba-style state space model layer with causal convolution.

Implements a selective SSM with a hardware-aware scan, a causal conv1d prefix, and a SiLU gating branch. The SSM's internal expansion replaces the traditional FFN/MLP role in the block.

Attributes:

in_proj: Projects input to 2x intermediate dim (x and z branches).
conv1d: Depthwise causal 1D convolution.
x_proj: Projects conv output to B, C, dt parameters.
dt_proj: Projects raw dt to per-channel step sizes.
out_proj: Projects SSM output back to hidden dim.
A_log: Log of the diagonal SSM state matrix (learnable).
D: Skip connection parameter per channel.

Initialize the AxonSSM layer.

Parameters:

config (BitAxonConfig, required): BitAxonConfig with ssm_intermediate_dim, ssm_d_state, ssm_d_conv, and hidden_dim settings.
Attributes

in_proj = Linear(D, 2 * E, bias=False)
conv1d = Conv1d(E, E, kernel_size=d_conv, groups=E, bias=True)
x_proj = Linear(E, d_state * 2 + 1, bias=False)
dt_proj = Linear(1, E, bias=True)
out_proj = Linear(E, D, bias=False)
d_conv = d_conv
d_state = d_state
E = E
A_log = log(A)
D = ones((E,))
Functions
__call__
__call__(x, cache=None)

Run the SSM forward pass.

Parameters:

x (required): Input tensor of shape (batch, seq_len, hidden_dim).
cache (default None): Optional [conv_cache, ssm_state] from a previous step.

Returns:

Tuple of (output, new_cache). Output has shape (batch, seq_len, hidden_dim). new_cache is [updated_conv_cache, updated_ssm_state] for autoregressive decoding.
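The recurrence this layer evaluates can be sketched as a sequential NumPy reference. This is not the hardware-aware scan, and the A = -exp(A_log) parameterization is an assumption carried over from standard Mamba-style layers; shapes follow the attribute docs above (E channels, N = d_state):

```python
import numpy as np

def selective_scan(x, dt, A_log, B, C, D):
    """Sequential reference for the diagonal selective SSM.

    x:     (L, E)  per-timestep input (conv branch output)
    dt:    (L, E)  per-channel step sizes (as produced by dt_proj)
    A_log: (E, N)  log of the (negated) diagonal state matrix
    B, C:  (L, N)  input/output projections produced by x_proj
    D:     (E,)    per-channel skip connection
    """
    L, E = x.shape
    A = -np.exp(A_log)                       # (E, N); negative keeps the state stable
    h = np.zeros_like(A)                     # SSM state, one N-vector per channel
    ys = np.empty((L, E))
    for t in range(L):
        dA = np.exp(dt[t][:, None] * A)      # discretized state transition
        dB = dt[t][:, None] * B[t][None, :]  # discretized input matrix
        h = dA * h + dB * x[t][:, None]      # state update
        ys[t] = h @ C[t] + D * x[t]          # readout plus skip connection
    return ys
```

The per-channel dt shape matches dt_proj = Linear(1, E): x_proj emits a single raw dt scalar per timestep, which dt_proj expands to E step sizes.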

AxonSSMBlock
AxonSSMBlock(config: BitAxonConfig)

Bases: Module

Pure SSM block with no MLP; the SSM's internal expansion serves the FFN role.

Initialize the SSM block.

Parameters:

config (BitAxonConfig, required): BitAxonConfig with hidden_dim and rms_norm_eps settings.
Attributes

input_norm = RMSNorm(hidden_dim, rms_norm_eps)
ssm = AxonSSM(config)
Functions
__call__
__call__(x, cache=None)

Forward pass with residual connection.

Parameters:

x (required): Input tensor of shape (batch, seq_len, hidden_dim).
cache (default None): Optional SSM cache from a previous step.

Returns:

Tuple of (output, cache). Output has shape (batch, seq_len, hidden_dim).
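The forward pass described above is a standard pre-norm residual wrapper. A minimal sketch, with hypothetical `input_norm` and `ssm` callables standing in for the real modules:

```python
def ssm_block_forward(x, input_norm, ssm, cache=None):
    """Pre-norm residual: the SSM sees normalized input,
    and the raw input is added back onto its output."""
    h, new_cache = ssm(input_norm(x), cache)
    return x + h, new_cache
```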

AxonSWAMoEBlock
AxonSWAMoEBlock(config: BitAxonConfig)

Bases: Module

SWA + MoE block with sliding window attention and sparse experts.

Initialize the SWA + MoE block.

Parameters:

config (BitAxonConfig, required): BitAxonConfig with attention and MoE hyperparameters.
Attributes

input_norm = RMSNorm(hidden_dim, rms_norm_eps)
attention = SlidingWindowAttention(hidden_dim, num_heads, swa_window_size)
post_attention_norm = RMSNorm(hidden_dim, rms_norm_eps)
moe = SharedExpertMoE(hidden_dim, moe_intermediate_dim, moe_num_experts, moe_top_k)
Functions
__call__
__call__(x, cache=None)

Forward pass: attention with residual, then MoE with residual.

Parameters:

x (required): Input tensor of shape (batch, seq_len, hidden_dim).
cache (default None): Optional KVCache for autoregressive decoding.

Returns:

Tuple of (output, cache). Output has shape (batch, seq_len, hidden_dim).
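The two-residual structure (attention with residual, then MoE with residual, each pre-normed) can be sketched with stand-in callables; the names mirror the block's attributes but the functions here are hypothetical placeholders:

```python
def swa_moe_block_forward(x, input_norm, attention, post_attention_norm, moe,
                          cache=None):
    """Two pre-norm residual stages: attention, then mixture of experts."""
    h, new_cache = attention(input_norm(x), cache)
    x = x + h                             # residual around attention
    x = x + moe(post_attention_norm(x))   # residual around the MoE
    return x, new_cache
```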

AxonSSMMoEBlock
AxonSSMMoEBlock(config: BitAxonConfig)

Bases: Module

SSM + MoE block with linear recurrence and sparse experts.

Initialize the SSM + MoE block.

Parameters:

config (BitAxonConfig, required): BitAxonConfig with SSM and MoE hyperparameters.
Attributes

input_norm = RMSNorm(hidden_dim, rms_norm_eps)
ssm = AxonSSM(config)
post_ssm_norm = RMSNorm(hidden_dim, rms_norm_eps)
moe = SharedExpertMoE(hidden_dim, moe_intermediate_dim, moe_num_experts, moe_top_k)
Functions
__call__
__call__(x, cache=None)

Forward pass: SSM with residual, then MoE with residual.

Parameters:

x (required): Input tensor of shape (batch, seq_len, hidden_dim).
cache (default None): Optional SSM cache from a previous step.

Returns:

Tuple of (output, ssm_cache). Output has shape (batch, seq_len, hidden_dim).

SharedExpertMoE
SharedExpertMoE(dim: int, intermediate_dim: int, num_experts: int = 8, top_k: int = 2, bias: bool = False)

Bases: Module

Mixture of experts with top-k routing and a gated shared expert.

Routes each token to the top-k experts out of num_experts via softmax gating, while a shared expert processes all tokens. The shared expert output is gated by a learned sigmoid to allow dynamic blending.

Attributes:

gate: Router producing per-expert softmax scores.
switch_mlp: SwitchGLU implementing per-expert projections.
shared_expert: MLP applied to all tokens.
shared_expert_gate: Learned gate controlling shared expert contribution.

Initialize SharedExpertMoE.

Parameters:

dim (int, required): Input and output dimension.
intermediate_dim (int, required): Hidden dimension for each expert's SwiGLU.
num_experts (int, default 8): Total number of routed experts.
top_k (int, default 2): Number of experts activated per token.
bias (bool, default False): Whether to include bias in projections.
Attributes

top_k = top_k
num_experts = num_experts
gate = Linear(dim, num_experts, bias=False)
switch_mlp = SwitchGLU(dim, intermediate_dim, num_experts)
shared_expert_gate = Linear(dim, 1, bias=False)
shared_expert = MLP(dim, intermediate_dim, bias=bias)
Functions
__call__
__call__(x: array) -> array

Forward pass: route tokens to top-k experts and blend with shared expert.

Parameters:

x (array, required): Input tensor of shape (batch, seq_len, dim).

Returns:

array: Output tensor of shape (batch, seq_len, dim).
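The routing described above can be sketched in plain NumPy. This is a per-token reference with toy matrix experts standing in for SwitchGLU and the shared MLP, and renormalizing the top-k gate scores is an assumption, not something the docstring specifies:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def shared_expert_moe(x, W_gate, experts, shared, w_shared_gate, top_k=2):
    """x: (T, D) tokens; experts: list of (D, D) toy expert matrices;
    shared: (D, D) shared-expert matrix; w_shared_gate: (D,) gate weights."""
    scores = softmax(x @ W_gate)                   # (T, num_experts) router
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        top = np.argsort(scores[t])[-top_k:]       # indices of top-k experts
        w = scores[t, top] / scores[t, top].sum()  # renormalize over top-k
        for wi, e in zip(w, top):
            out[t] += wi * (x[t] @ experts[e])     # weighted routed experts
        g = sigmoid(x[t] @ w_shared_gate)          # learned sigmoid gate
        out[t] += g * (x[t] @ shared)              # shared expert, every token
    return out
```

The shared expert runs on all tokens regardless of routing; the sigmoid gate lets the model blend its contribution per token.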

SlidingWindowAttention
SlidingWindowAttention(hidden_dim: int, num_heads: int, window_size: int)

Bases: Module

Multi-head attention with a sliding window mask.

Restricts each token to attend only to tokens within a fixed window, reducing attention cost from O(n²) to O(n * window_size).

Attributes:

q_proj: Query projection.
k_proj: Key projection.
v_proj: Value projection.
o_proj: Output projection.

Initialize sliding window attention.

Parameters:

hidden_dim (int, required): Model hidden dimension.
num_heads (int, required): Number of attention heads.
window_size (int, required): Maximum attention distance per side.
Attributes

num_heads = num_heads
head_dim = hidden_dim // num_heads
window_size = window_size
scale = 1.0 / sqrt(head_dim)
q_proj = Linear(hidden_dim, hidden_dim, bias=False)
k_proj = Linear(hidden_dim, hidden_dim, bias=False)
v_proj = Linear(hidden_dim, hidden_dim, bias=False)
o_proj = Linear(hidden_dim, hidden_dim, bias=False)
Functions
__call__
__call__(x, mask=None, cache=None)

Forward pass for sliding window attention.

Parameters:

x (required): Input tensor of shape (batch, seq_len, hidden_dim).
mask (default None): Optional attention mask to add to scores. If None, a sliding window causal mask is generated automatically.
cache (default None): Optional KVCache for autoregressive decoding.

Returns:

Tuple of (output, cache). Output has shape (batch, seq_len, hidden_dim). cache is the updated KVCache or None.
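The automatically generated mask can be sketched as an additive mask in NumPy; whether the bound is inclusive of exactly `window_size` past tokens is an assumption about the real implementation:

```python
import numpy as np

def sliding_window_causal_mask(seq_len, window_size):
    """Additive mask: 0.0 where attention is allowed, -inf where blocked.

    Position i may attend to positions j with i - window_size <= j <= i,
    i.e. causal attention restricted to the most recent tokens.
    """
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    allowed = (j <= i) & (j >= i - window_size)
    return np.where(allowed, 0.0, -np.inf)
```

Adding this mask to the raw attention scores before the softmax zeroes the probability of any position outside the window.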

RMSNorm
RMSNorm(dims: int, eps: float = 1e-06)

Bases: Module

RMS normalization using MLX's fast kernel.

Unlike LayerNorm, does not center the input (no mean subtraction). Normalizes by the RMS of the input and applies a learned scale.

Attributes

weight = ones((dims,))
eps = eps
Functions
__call__
__call__(x: array) -> array
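A plain NumPy reference for the computation (the real layer dispatches to MLX's fast kernel):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    """y = x / sqrt(mean(x^2) + eps) * weight, over the last axis.

    Unlike LayerNorm there is no mean subtraction: only the RMS
    of the input is divided out before the learned scale is applied.
    """
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight
```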