# Inference

## bit_axon.inference
Inference utilities for Bit-Axon.
### Classes

#### GenerateConfig dataclass

```python
GenerateConfig(
    max_tokens: int = 512,
    temperature: float = 0.6,
    top_k: int = 50,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    seed: int | None = None,
)
```
Configuration for autoregressive text generation.
Attributes:

| Name | Type | Description |
|---|---|---|
| `max_tokens` | `int` | Maximum number of tokens to generate. |
| `temperature` | `float` | Sampling temperature. Higher values increase randomness. |
| `top_k` | `int` | Number of top logits to keep during sampling. 0 disables filtering. |
| `top_p` | `float` | Nucleus sampling probability threshold. 1.0 disables filtering. |
| `repetition_penalty` | `float` | Penalty for repeated tokens. 1.0 disables the penalty. |
| `seed` | `int \| None` | Optional random seed for reproducible generation. |
#### GenerateResult dataclass

```python
GenerateResult(
    text: str,
    token_ids: list[int],
    prompt_tokens: int,
    completion_tokens: int,
    tokens_per_sec: float,
    time_to_first_token_ms: float | None = None,
)
```
Result of text generation.
Attributes:

| Name | Type | Description |
|---|---|---|
| `text` | `str` | Decoded output text. |
| `token_ids` | `list[int]` | Generated token IDs (excluding the prompt). |
| `prompt_tokens` | `int` | Number of tokens in the input prompt. |
| `completion_tokens` | `int` | Number of tokens generated. |
| `tokens_per_sec` | `float` | Generation throughput in tokens per second. |
| `time_to_first_token_ms` | `float \| None` | Time from prefill start to the first sampled token, in milliseconds. |
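The throughput fields can be derived by timing the decode loop with a wall clock. A minimal sketch (the `decode` callable here is hypothetical, not part of the library):

```python
import time


def measure_throughput(decode, prompt_ids):
    """Return (token_ids, tokens_per_sec) for a hypothetical decode callable."""
    start = time.perf_counter()
    token_ids = decode(prompt_ids)  # stand-in for the model's decode loop
    elapsed = time.perf_counter() - start
    # completion_tokens / wall-clock seconds, guarding against a zero timer delta
    tokens_per_sec = len(token_ids) / elapsed if elapsed > 0 else float("inf")
    return token_ids, tokens_per_sec


ids, tps = measure_throughput(lambda p: [7, 8, 9], [1, 2])
```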
### Functions

#### load_model

```python
load_model(
    weights_path: str | Path,
    config: BitAxonConfig | None = None,
    quantize: bool = False,
    bits: int = 4,
    group_size: int = 64,
) -> BitAxonModel
```
Load a BitAxonModel from disk with optional NF4 quantization.
Loads weights from safetensors files in `weights_path`. If no `config` is provided, attempts to read `config.json` from the same directory.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `weights_path` | `str \| Path` | Directory containing safetensors weight files. | *required* |
| `config` | `BitAxonConfig \| None` | Model configuration. Falls back to `config.json` or defaults. | `None` |
| `quantize` | `bool` | If True, replace Linear layers with QuantizedLinear. | `False` |
| `bits` | `int` | Quantization bit width (default 4 for NF4). | `4` |
| `group_size` | `int` | Quantization group size. | `64` |
Returns:

| Type | Description |
|---|---|
| `BitAxonModel` | Loaded model with weights applied. |
#### sample_logits

```python
sample_logits(
    logits: array,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    seed: int | None = None,
) -> array
```
Sample token IDs from logits with temperature, top-k, and top-p filtering.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `array` | Logits of shape `(batch, vocab_size)` or `(vocab_size,)`. | *required* |
| `temperature` | `float` | Sampling temperature. 0.0 = greedy (argmax). | `1.0` |
| `top_k` | `int` | Keep only the top-k logits. 0 = disabled. | `0` |
| `top_p` | `float` | Nucleus sampling threshold. 1.0 = disabled. | `1.0` |
| `seed` | `int \| None` | Optional random seed for reproducibility. | `None` |
Returns:

| Type | Description |
|---|---|
| `array` | Sampled token IDs with shape `(batch,)`, or a scalar for 1-D input. |
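The filtering pipeline can be reproduced in plain NumPy. This is a sketch of the documented semantics, not the library's implementation (which operates on its own `array` type):

```python
import numpy as np


def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, seed=None):
    """NumPy sketch of temperature, top-k, and top-p (nucleus) sampling."""
    logits = np.asarray(logits, dtype=np.float64)
    squeeze = logits.ndim == 1
    if squeeze:
        logits = logits[None, :]

    if temperature == 0.0:  # greedy decoding: just take the argmax
        ids = logits.argmax(axis=-1)
        return ids[0] if squeeze else ids

    logits = logits / temperature

    if top_k > 0:  # mask out everything below the k-th largest logit per row
        kth = np.sort(logits, axis=-1)[:, -top_k][:, None]
        logits = np.where(logits < kth, -np.inf, logits)

    # numerically stable softmax
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)

    if top_p < 1.0:  # keep the smallest prefix of sorted probs with mass >= top_p
        order = np.argsort(probs, axis=-1)[:, ::-1]
        sorted_probs = np.take_along_axis(probs, order, axis=-1)
        cum = np.cumsum(sorted_probs, axis=-1)
        sorted_probs[cum - sorted_probs > top_p] = 0.0  # drop tokens past the threshold
        probs = np.zeros_like(probs)
        np.put_along_axis(probs, order, sorted_probs, axis=-1)
        probs /= probs.sum(axis=-1, keepdims=True)

    rng = np.random.default_rng(seed)
    ids = np.array([rng.choice(p.size, p=p) for p in probs])
    return ids[0] if squeeze else ids
```

Note that `top_k=1`, like `temperature=0.0`, collapses sampling to the argmax, since only one logit survives the filter.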