
Inference

bit_axon.inference

Inference utilities for Bit-Axon.

Classes

GenerateConfig dataclass
GenerateConfig(max_tokens: int = 512, temperature: float = 0.6, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0, seed: int | None = None)

Configuration for autoregressive text generation.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| max_tokens | int | Maximum number of tokens to generate. |
| temperature | float | Sampling temperature. Higher values increase randomness. |
| top_k | int | Number of top logits to keep during sampling. 0 disables filtering. |
| top_p | float | Nucleus sampling probability threshold. 1.0 disables filtering. |
| repetition_penalty | float | Penalty applied to repeated tokens. 1.0 disables the penalty. |
| seed | int \| None | Optional random seed for reproducible generation. |

Attributes
max_tokens: int = 512
temperature: float = 0.6
top_k: int = 50
top_p: float = 0.95
repetition_penalty: float = 1.0
seed: int | None = None
GenerateResult dataclass
GenerateResult(text: str, token_ids: list[int], prompt_tokens: int, completion_tokens: int, tokens_per_sec: float, time_to_first_token_ms: float | None = None)

Result of text generation.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| text | str | Decoded output text. |
| token_ids | list[int] | Generated token IDs (excluding the prompt). |
| prompt_tokens | int | Number of tokens in the input prompt. |
| completion_tokens | int | Number of tokens generated. |
| tokens_per_sec | float | Generation throughput in tokens per second. |
| time_to_first_token_ms | float \| None | Time from prefill start to the first sampled token, in milliseconds. |

Attributes
text: str
token_ids: list[int]
prompt_tokens: int
completion_tokens: int
tokens_per_sec: float
time_to_first_token_ms: float | None = None
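The timing fields relate to wall-clock measurements in the obvious way. The sketch below shows one plausible derivation from three timestamps; exactly where `bit_axon` places its measurement points (for example, whether prefill time counts toward `tokens_per_sec`) is an assumption here, not documented behaviour.

```python
def throughput_stats(token_ids, start_s, first_token_s, end_s):
    """Derive GenerateResult timing fields from wall-clock marks.

    Sketch only: the real measurement points inside bit_axon
    may differ from the ones assumed here.
    """
    completion_tokens = len(token_ids)
    return {
        "completion_tokens": completion_tokens,
        # Tokens generated divided by total elapsed seconds.
        "tokens_per_sec": completion_tokens / (end_s - start_s),
        # Prefill start to first sampled token, in milliseconds.
        "time_to_first_token_ms": (first_token_s - start_s) * 1000.0,
    }

stats = throughput_stats(list(range(10)), start_s=0.0, first_token_s=0.25, end_s=2.0)
```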

Functions

load_model
load_model(weights_path: str | Path, config: BitAxonConfig | None = None, quantize: bool = False, bits: int = 4, group_size: int = 64) -> BitAxonModel

Load a BitAxonModel from disk with optional NF4 quantization.

Loads weights from safetensors files in weights_path. If no config is provided, attempts to read config.json from the same directory.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| weights_path | str \| Path | Directory containing safetensors weight files. | required |
| config | BitAxonConfig \| None | Model configuration. Falls back to config.json or defaults. | None |
| quantize | bool | If True, replace Linear layers with QuantizedLinear. | False |
| bits | int | Quantization bit width (default 4 for NF4). | 4 |
| group_size | int | Quantization group size. | 64 |

Returns:

| Type | Description |
| --- | --- |
| BitAxonModel | Loaded BitAxonModel with weights applied. |
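The config-resolution order described above (explicit argument first, then config.json beside the weights, then defaults) can be sketched as follows. `resolve_config` is a hypothetical helper written for illustration, not part of the `bit_axon` API, and `None` stands in for "fall back to a default BitAxonConfig()".

```python
import json
from pathlib import Path

def resolve_config(weights_path, config=None):
    """Hypothetical helper mirroring the documented fallback order:
    1. an explicitly passed config wins;
    2. otherwise config.json in the weights directory is read;
    3. otherwise None is returned (i.e. use library defaults)."""
    if config is not None:
        return config
    cfg_file = Path(weights_path) / "config.json"
    if cfg_file.exists():
        return json.loads(cfg_file.read_text())
    return None

# Typical shape of a real call (not executed here):
# model = load_model("weights/", quantize=True, bits=4, group_size=64)
```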

sample_logits
sample_logits(logits: array, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, seed: int | None = None) -> array

Sample token IDs from logits with temperature, top-k, and top-p filtering.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| logits | array | Shape (batch, vocab_size) or (vocab_size,). | required |
| temperature | float | Sampling temperature. 0.0 = greedy (argmax). | 1.0 |
| top_k | int | Keep only the top-k logits. 0 = disabled. | 0 |
| top_p | float | Nucleus sampling threshold. 1.0 = disabled. | 1.0 |
| seed | int \| None | Optional random seed for reproducibility. | None |

Returns:

| Type | Description |
| --- | --- |
| array | Token IDs with shape (batch,) or a scalar. |
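To make the filtering order concrete, here is a pure-Python sketch of the same pipeline for a single (vocab_size,) logit vector: temperature scaling, then top-k, then softmax, then top-p, then sampling. This is an illustration only; `bit_axon`'s array-based implementation may differ in batching, tie-breaking, and the exact order of operations.

```python
import math
import random

def sample_logits_sketch(logits, temperature=1.0, top_k=0, top_p=1.0, seed=None):
    """Illustrative re-implementation of the documented sampling pipeline
    for a flat list of logits. Returns a single token index."""
    # temperature == 0.0 means greedy decoding (argmax).
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / temperature for l in logits]
    # Top-k: drop everything below the k-th largest logit.
    if top_k > 0:
        kth = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [l if l >= kth else float("-inf") for l in scaled]
    # Numerically stable softmax; exp(-inf) == 0.0 removes filtered tokens.
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p (nucleus): keep the smallest set of tokens whose mass >= top_p.
    if top_p < 1.0:
        order = sorted(range(len(probs)), key=lambda i: -probs[i])
        kept, mass = set(), 0.0
        for i in order:
            kept.add(i)
            mass += probs[i]
            if mass >= top_p:
                break
        probs = [p if i in kept else 0.0 for i, p in enumerate(probs)]
        z = sum(probs)
        probs = [p / z for p in probs]
    rng = random.Random(seed)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Note that with `top_k=1` (or a vanishingly small `top_p`) all probability mass collapses onto the largest logit, so sampling degenerates to argmax regardless of temperature.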