Model

bit_axon.model.BitAxonModel

BitAxonModel(config: BitAxonConfig)

Bases: Module

3.2B-parameter hybrid language model combining state-space model (SSM), sliding-window attention (SWA), and mixture-of-experts (MoE) layers.

24-layer sandwich architecture:
  • Layers 0-7: Pure SSM (linear recurrence, no KV cache)
  • Layers 8-15: SWA + MoE (sliding window attention + sparse experts)
  • Layers 16-23: SSM + MoE (linear recurrence + sparse experts)

Attributes:
  • config: Model configuration.
  • embed_tokens: Token embedding table.
  • input_proj: Projects from the source model dimension to the hidden dimension.
  • output_proj: Projects from the hidden dimension back to the source model dimension.
  • lm_head: Output projection to vocabulary logits.

Initialize the BitAxon model.

Parameters:
  • config (BitAxonConfig): Configuration object carrying the architecture hyperparameters. Required.

Functions

__call__
__call__(input_ids: array, cache=None)

Forward pass through all layers.

Parameters:
  • input_ids (array): Token indices of shape (batch, seq_len). Required.
  • cache: Optional list of per-layer caches from a previous call. Defaults to None.

Returns:
  • Tuple of (logits, new_caches). Logits have shape (batch, seq_len, vocab_size); new_caches is a list of updated per-layer caches.
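The cache-threading pattern implied by this signature can be sketched as follows. This is a runnable illustration, not the library code: a stub stands in for BitAxonModel, and its internals (layer count, vocabulary size, logit values) are placeholders; only the call/return contract follows the documentation above.

```python
class StubBitAxon:
    """Placeholder with BitAxonModel's documented call signature."""
    NUM_LAYERS = 24
    VOCAB_SIZE = 8  # placeholder; the real model projects to its full vocab

    def __call__(self, input_ids, cache=None):
        batch, seq_len = len(input_ids), len(input_ids[0])
        # Dummy logits of shape (batch, seq_len, vocab_size).
        logits = [[[0.0] * self.VOCAB_SIZE for _ in range(seq_len)]
                  for _ in range(batch)]
        # One cache slot per layer, carried forward across calls.
        new_caches = list(cache) if cache is not None else [None] * self.NUM_LAYERS
        return logits, new_caches

model = StubBitAxon()
# Prefill: run the whole prompt once and keep the per-layer caches.
logits, caches = model([[1, 2, 3]])
# Decode: feed one new token at a time, passing the caches back in.
next_logits, caches = model([[4]], cache=caches)
```

The same two-phase shape (prefill once, then decode token-by-token while threading `cache` through each call) is the intended use of the `cache` parameter.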

_get_layer_type staticmethod
_get_layer_type(layer_idx: int, total_layers: int) -> str

Determine layer type based on position in the sandwich architecture.

Parameters:
  • layer_idx (int): Zero-based layer index. Required.
  • total_layers (int): Total number of layers. Required.

Returns:
  • str: One of "ssm", "swa_moe", or "ssm_moe".
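The sandwich partition this method describes can be sketched as a standalone function. This is a hypothetical reconstruction from the documented layer ranges (thirds of the stack), not the library source:

```python
def get_layer_type(layer_idx: int, total_layers: int) -> str:
    """Map a zero-based layer index to its block type in the sandwich."""
    third = total_layers // 3
    if layer_idx < third:       # first third (layers 0-7 of 24): pure SSM
        return "ssm"
    if layer_idx < 2 * third:   # middle third (layers 8-15): SWA + MoE
        return "swa_moe"
    return "ssm_moe"            # final third (layers 16-23): SSM + MoE
```

For the documented 24-layer configuration, `[get_layer_type(i, 24) for i in range(24)]` yields eight "ssm", then eight "swa_moe", then eight "ssm_moe" entries.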

_create_caches
_create_caches() -> list[object]

Create KV caches for SWA layers; None for SSM layers.

Returns:
  • list[object]: KVCache objects for swa_moe layers and None for ssm/ssm_moe layers.
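A minimal sketch of the per-layer cache layout this describes. Both details here are assumptions about the internals: an empty dict stands in for the real KVCache class (which is not shown in this doc), and the middle-third rule mirrors the sandwich layout above.

```python
def create_caches(total_layers: int = 24) -> list:
    """One slot per layer: a cache object for SWA layers, None for SSM layers."""
    third = total_layers // 3
    # Only the middle third (swa_moe layers) needs a KV cache; SSM layers
    # carry recurrent state instead. A dict is a stand-in for KVCache.
    return [{} if third <= i < 2 * third else None for i in range(total_layers)]
```

Keeping None entries for SSM layers preserves positional alignment, so a cache list can be indexed by layer without special-casing layer types.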