Training

bit_axon.training

Bit-Axon training module: SFT, ORPO alignment, LoRA/DoRA adapters, and model merging.

Classes

TrainingConfig dataclass
TrainingConfig(learning_rate: float = 0.0001, weight_decay: float = 0.01, warmup_steps: int = 100, max_steps: int = 10000, max_grad_norm: float = 1.0, grad_accum_steps: int = 4, lora_rank: int = 8, lora_dropout: float = 0.0, lora_scale: float = 20.0, lora_targets: tuple[str, ...] = ('q_proj', 'k_proj', 'v_proj', 'o_proj', 'in_proj', 'out_proj', 'gate_proj', 'up_proj', 'down_proj', 'input_proj', 'output_proj'), use_dora: bool = True, beta: float = 0.1, training_mode: str = 'sft', quantize_bits: int = 4, quantize_group_size: int = 64, batch_size: int = 1, max_seq_len: int = 2048, save_every: int = 500, eval_every: int = 500, output_dir: str = 'checkpoints', temp_max_speed: float = 75.0, temp_pause: float = 85.0, temp_stop: float = 95.0, temp_poll_interval: float = 1.0, seed: int = 42)

Bit-Axon SFT training configuration.

Thermal-aware QLoRA fine-tuning for fanless MacBook Air M4.
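The three temp_* thresholds above suggest a tiered gating policy: run freely below temp_max_speed, throttle above it, pause above temp_pause, and abort above temp_stop. A minimal sketch of such a policy follows; the function and action names are illustrative, not the library's own API:

```python
# Hypothetical sketch of the thermal gating implied by the temp_*
# thresholds (not the library's actual implementation).

def thermal_action(temp_c: float,
                   max_speed: float = 75.0,
                   pause: float = 85.0,
                   stop: float = 95.0) -> str:
    """Map a package temperature (deg C) to a training action."""
    if temp_c >= stop:
        return "stop"        # abort training and checkpoint
    if temp_c >= pause:
        return "pause"       # sleep until the chip cools down
    if temp_c >= max_speed:
        return "throttle"    # insert delays between steps
    return "full_speed"
```

With temp_poll_interval = 1.0, a loop like this would be consulted roughly once per second.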

Attributes
learning_rate class-attribute instance-attribute
learning_rate: float = 0.0001
weight_decay class-attribute instance-attribute
weight_decay: float = 0.01
warmup_steps class-attribute instance-attribute
warmup_steps: int = 100
max_steps class-attribute instance-attribute
max_steps: int = 10000
max_grad_norm class-attribute instance-attribute
max_grad_norm: float = 1.0
grad_accum_steps class-attribute instance-attribute
grad_accum_steps: int = 4
lora_rank class-attribute instance-attribute
lora_rank: int = 8
lora_dropout class-attribute instance-attribute
lora_dropout: float = 0.0
lora_scale class-attribute instance-attribute
lora_scale: float = 20.0
lora_targets class-attribute instance-attribute
lora_targets: tuple[str, ...] = ('q_proj', 'k_proj', 'v_proj', 'o_proj', 'in_proj', 'out_proj', 'gate_proj', 'up_proj', 'down_proj', 'input_proj', 'output_proj')
use_dora class-attribute instance-attribute
use_dora: bool = True
beta class-attribute instance-attribute
beta: float = 0.1
training_mode class-attribute instance-attribute
training_mode: str = 'sft'
quantize_bits class-attribute instance-attribute
quantize_bits: int = 4
quantize_group_size class-attribute instance-attribute
quantize_group_size: int = 64
batch_size class-attribute instance-attribute
batch_size: int = 1
max_seq_len class-attribute instance-attribute
max_seq_len: int = 2048
save_every class-attribute instance-attribute
save_every: int = 500
eval_every class-attribute instance-attribute
eval_every: int = 500
output_dir class-attribute instance-attribute
output_dir: str = 'checkpoints'
temp_max_speed class-attribute instance-attribute
temp_max_speed: float = 75.0
temp_pause class-attribute instance-attribute
temp_pause: float = 85.0
temp_stop class-attribute instance-attribute
temp_stop: float = 95.0
temp_poll_interval class-attribute instance-attribute
temp_poll_interval: float = 1.0
seed class-attribute instance-attribute
seed: int = 42
Trainer
Trainer(model: Module, config, dataset, val_dataset=None, cooling_scheduler=None)

Thermal-aware QLoRA SFT trainer for Bit-Axon.

Orchestrates the full training loop: data iteration, gradient accumulation, gradient clipping, checkpointing, evaluation, and thermal gating.
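The accumulation-and-clip step can be sketched with plain floats standing in for tensors. This is an illustrative helper, not the Trainer's code; with the defaults batch_size = 1 and grad_accum_steps = 4, one optimizer step sees an effective batch of 4:

```python
# Sketch of gradient accumulation with clipping, using floats in
# place of tensors (illustrative only, not the Trainer's code).

def accumulated_update(param: float, micro_grads: list[float],
                       lr: float = 1e-4, max_grad_norm: float = 1.0) -> float:
    """Average micro-batch gradients, clip, then take one SGD-style step."""
    grad = sum(micro_grads) / len(micro_grads)   # mean over grad_accum_steps
    norm = abs(grad)                             # |g| stands in for the global norm
    if norm > max_grad_norm:
        grad *= max_grad_norm / norm             # clip to max_grad_norm
    return param - lr * grad
```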

Attributes
model instance-attribute
model = model
config instance-attribute
config: TrainingConfig = config
dataset instance-attribute
dataset = dataset
val_dataset instance-attribute
val_dataset = val_dataset
cooling instance-attribute
cooling = cooling_scheduler
optimizer instance-attribute
optimizer = None
step_count instance-attribute
step_count = 0
Functions
setup
setup() -> None

Initialize optimizer, scheduler, and loss function. Resume from checkpoint if available.

train
train() -> dict

Run the main training loop.

Returns:

Type Description
dict

Dict with final training stats: {"step": int, "loss": float, "grad_norm": float}

evaluate
evaluate() -> dict

Run evaluation on val_dataset.

Returns:

Type Description
dict

{"loss": float, "perplexity": float}

ORPOTrainer
ORPOTrainer(model: Module, config, dataset, val_dataset=None, cooling_scheduler=None)

Thermal-aware ORPO preference trainer for Bit-Axon.

Performs simultaneous SFT and preference alignment using the ORPO objective (no reference model required). Uses two forward passes per batch (chosen + rejected) with gradient accumulation.
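The ORPO objective combines the SFT negative log-likelihood on the chosen response with a beta-weighted odds-ratio penalty that pushes the chosen response's odds above the rejected one's. A scalar sketch on average per-token log-probabilities (the trainer's real loss operates on logits; these helper names are illustrative):

```python
import math

# Scalar sketch of the ORPO loss; avg log-probs must be negative (p < 1).

def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) with p = exp(average per-token log-prob)."""
    p = math.exp(avg_logp)
    return math.log(p) - math.log1p(-p)

def orpo_loss(nll_chosen: float, avg_logp_chosen: float,
              avg_logp_rejected: float, beta: float = 0.1) -> float:
    """SFT NLL on the chosen response plus a beta-weighted odds-ratio penalty."""
    ratio = log_odds(avg_logp_chosen) - log_odds(avg_logp_rejected)
    # -log(sigmoid(ratio)) written stably as log1p(exp(-ratio))
    penalty = math.log1p(math.exp(-ratio))
    return nll_chosen + beta * penalty
```

When chosen and rejected are equally likely the penalty reduces to beta * log 2; widening the margin in favor of chosen shrinks it, which matches the absence of a reference model in ORPO.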

Attributes
model instance-attribute
model = model
config instance-attribute
config: TrainingConfig = config
dataset instance-attribute
dataset = dataset
val_dataset instance-attribute
val_dataset = val_dataset
cooling instance-attribute
cooling = cooling_scheduler
optimizer instance-attribute
optimizer = None
step_count instance-attribute
step_count = 0
Functions
setup
setup() -> None

Initialize optimizer, scheduler, and loss function. Resume from checkpoint if available.

train
train() -> dict

Run the main ORPO training loop.

Iterates over preference pairs, computing the ORPO loss (NLL + odds-ratio penalty), accumulating gradients, and updating the model. Tracks reward metrics (chosen/rejected log-probs and their margin) at each step.

Returns:

Type Description
dict

Dict with final training stats including reward metrics: {"step", "loss", "grad_norm", "chosen_reward", "rejected_reward", "reward_margin", "reward_accuracy"}.

evaluate
evaluate() -> dict

Run evaluation on val_dataset.

Iterates over the validation preference pairs (up to 10 batches), computing ORPO loss and reward margin metrics without gradients.

Returns:

Type Description
dict

{"loss": float, "perplexity": float, "reward_margin": float}

LoRALinear
LoRALinear(input_dims, output_dims, r=8, dropout=0.0, scale=20.0, bias=False)

Bases: Module

Low-Rank Adaptation wrapper around a base Linear layer.

Adds a trainable low-rank decomposition (lora_a @ lora_b) scaled by scale and added to the base layer output. Base weights are frozen.
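The forward path can be sketched in pure Python with lists in place of arrays (dropout omitted; `matmul` and `lora_forward` are illustrative helpers, assuming the MLX convention y = x @ W^T with W of shape (output_dims, input_dims)):

```python
def matmul(a, b):
    """Naive matrix product for lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def lora_forward(x, w, a, b, scale=20.0):
    """y = x @ W^T + scale * (x @ A @ B); the base weight W stays frozen."""
    base = matmul(x, [list(col) for col in zip(*w)])  # x @ W^T
    delta = matmul(matmul(x, a), b)                   # low-rank path
    return [[bb + scale * dd for bb, dd in zip(br, dr)]
            for br, dr in zip(base, delta)]
```

Because lora_b is initialized to zeros, the low-rank delta is zero at the start of training and the wrapped layer initially reproduces the base layer exactly.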

Parameters:

Name Type Description Default
input_dims

Input dimension of the linear layer.

required
output_dims

Output dimension of the linear layer.

required
r

LoRA rank.

8
dropout

Dropout probability applied before the low-rank path.

0.0
scale

Scaling factor for the LoRA output.

20.0
bias

Whether to include a bias term in the base linear layer.

False

Attributes:

Name Type Description
linear

Base frozen linear layer.

lora_a

Low-rank matrix A of shape (input_dims, r).

lora_b

Low-rank matrix B of shape (r, output_dims), initialized to zeros.

scale

Output scaling factor.

Attributes
linear instance-attribute
linear = Linear(input_dims, output_dims, bias=bias)
dropout instance-attribute
dropout = Dropout(p=dropout)
scale instance-attribute
scale = scale
lora_a instance-attribute
lora_a = uniform(low=-init_scale, high=init_scale, shape=(input_dims, r))
lora_b instance-attribute
lora_b = zeros(shape=(r, output_dims))
Functions
__call__
__call__(x)
from_base staticmethod
from_base(linear, r=8, dropout=0.0, scale=20.0)

Create a LoRALinear wrapping an existing Linear or QuantizedLinear.

Parameters:

Name Type Description Default
linear

Base linear layer to wrap. Its weights are preserved.

required
r

LoRA rank.

8
dropout

Dropout probability.

0.0
scale

Output scaling factor.

20.0

Returns:

Type Description

LoRALinear with the base layer's weights and new LoRA matrices.

fuse
fuse(dequantize=False)

Fuse LoRA weights into the base layer, producing a plain nn.Linear.

Adds the scaled low-rank delta to the base weight. If the base is QuantizedLinear and dequantize=True, dequantizes before fusing.

Parameters:

Name Type Description Default
dequantize

If True, dequantize QuantizedLinear weights before fusing.

False

Returns:

Type Description

nn.Linear with fused weights.
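The arithmetic fuse performs can be checked on a toy example. Since W has shape (output_dims, input_dims) while A @ B has shape (input_dims, output_dims), the delta is transposed before being added; `fuse_lora` is an illustrative helper, not the library's method:

```python
def fuse_lora(w, a, b, scale=20.0):
    """Fold the low-rank delta into the base weight: W + scale * (A @ B)^T."""
    delta = [[sum(a[i][k] * b[k][j] for k in range(len(b)))
              for j in range(len(b[0]))] for i in range(len(a))]  # (in, out)
    return [[w[o][i] + scale * delta[i][o]
             for i in range(len(w[0]))] for o in range(len(w))]
```

After fusing, a plain forward pass through the returned weight matches the adapter's base-plus-delta output, so the adapter cost disappears at inference time.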

DoRALinear
DoRALinear(input_dims, output_dims, r=8, dropout=0.0, scale=20.0, bias=False)

Bases: Module

Weight-Decomposed Low-Rank Adaptation (DoRA) wrapper.

Like LoRA but re-normalizes the adapted output to match the magnitude of the original weight matrix. Stores the per-output-dim norm of the base weight in m and divides by the norm of the adapted weight during forward.
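The per-row re-normalization can be sketched as a scalar factor applied to each output dimension: the frozen magnitude m divided by the norm of the adapted row. `dora_row_scale` is an illustrative helper under that reading:

```python
import math

def dora_row_scale(w_row, delta_row, m):
    """DoRA factor for one output dim: frozen magnitude over adapted row norm."""
    adapted = [wi + di for wi, di in zip(w_row, delta_row)]
    norm = math.sqrt(sum(v * v for v in adapted))
    return m / norm
```

At initialization the low-rank delta is zero, so the factor is exactly 1 and DoRA coincides with plain LoRA; as training moves the adapted row away from the base row, the factor rescales it back to the original magnitude.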

Parameters:

Name Type Description Default
input_dims

Input dimension of the linear layer.

required
output_dims

Output dimension of the linear layer.

required
r

LoRA rank.

8
dropout

Dropout probability applied before the low-rank path.

0.0
scale

Scaling factor for the LoRA output.

20.0
bias

Whether to include a bias term in the base linear layer.

False

Attributes:

Name Type Description
linear

Base frozen linear layer.

lora_a

Low-rank matrix A of shape (input_dims, r).

lora_b

Low-rank matrix B of shape (r, output_dims), initialized to zeros.

m

Magnitude vector of shape (output_dims,), frozen norm of base weight rows.

scale

Output scaling factor.

Attributes
linear instance-attribute
linear = Linear(input_dims, output_dims, bias=bias)
dropout instance-attribute
dropout = Dropout(p=dropout)
scale instance-attribute
scale = scale
lora_a instance-attribute
lora_a = uniform(low=-init_scale, high=init_scale, shape=(input_dims, r))
lora_b instance-attribute
lora_b = zeros(shape=(r, output_dims))
m instance-attribute
m = norm(astype(float32), axis=1)
Functions
__call__
__call__(x)
from_base staticmethod
from_base(linear, r=8, dropout=0.0, scale=20.0)

Create a DoRALinear wrapping an existing Linear or QuantizedLinear.

Parameters:

Name Type Description Default
linear

Base linear layer to wrap. Its weights are preserved.

required
r

DoRA rank.

8
dropout

Dropout probability.

0.0
scale

Output scaling factor.

20.0

Returns:

Type Description

DoRALinear with the base layer's weights and magnitude vector.

fuse
fuse()

Fuse DoRA weights into the base layer, producing a plain nn.Linear.

Adds the scaled low-rank delta to the base weight, then re-normalizes each row to match the original magnitude stored in m.

Returns:

Type Description

nn.Linear with fused and magnitude-normalized weights.

SFTDataset
SFTDataset(data: list[dict] | str | Path, tokenizer: QwenTokenizerWrapper, max_seq_len: int = 2048, mask_prompt: bool = True)

Dataset for supervised fine-tuning with chat/messages format.

Expected JSONL format

{"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Yields (token_ids, loss_mask) tuples per example. If mask_prompt=True, loss is only computed on assistant response tokens.
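The effect of mask_prompt can be illustrated with a hypothetical helper (the real dataset builds the mask while tokenizing through the chat template, but the semantics are equivalent: zeros over the prompt span, ones over the assistant response):

```python
def build_loss_mask(token_ids: list[int], prompt_len: int,
                    mask_prompt: bool = True) -> list[int]:
    """1 where loss is computed (assistant tokens), 0 over the prompt."""
    if not mask_prompt:
        return [1] * len(token_ids)
    return [0] * prompt_len + [1] * (len(token_ids) - prompt_len)
```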

Functions
__len__
__len__() -> int
__getitem__
__getitem__(idx: int) -> tuple[list[int], list[int]]
__iter__
__iter__() -> Iterator[tuple[list[int], list[int]]]
ORPODataset
ORPODataset(data: list[dict[str, object]] | str | Path, tokenizer: QwenTokenizerWrapper, max_seq_len: int = 2048)

Dataset for ORPO (Odds Ratio Preference Optimization).

Expected JSONL format

{"prompt": [{"role": "user", "content": "..."}], "chosen": [{"role": "assistant", "content": "..."}], "rejected": [{"role": "assistant", "content": "..."}]}

Also supports string format

{"prompt": "...", "chosen": "...", "rejected": "..."}

Yields (chosen_ids, chosen_mask, rejected_ids, rejected_mask) per example.

Functions
__len__
__len__() -> int
__getitem__
__getitem__(idx: int) -> tuple[list[int], list[int], list[int], list[int]]
__iter__
__iter__() -> Iterator[tuple[list[int], list[int], list[int], list[int]]]

Functions

apply_lora_to_model
apply_lora_to_model(model: Module, rank: int = 8, dropout: float = 0.0, scale: float = 20.0, targets: tuple[str, ...] = DEFAULT_LORA_TARGETS, use_dora: bool = False) -> list[str]

Walk the model tree and replace target nn.Linear/nn.QuantizedLinear layers with LoRA/DoRA wrappers.

Layers matching names in LORA_EXCLUDED_NAMES or paths in LORA_EXCLUDED_PATHS are skipped regardless of target matching.

Parameters:

Name Type Description Default
model Module

BitAxonModel to apply adapters to.

required
rank int

LoRA rank.

8
dropout float

Dropout probability for the adapter.

0.0
scale float

Output scaling factor.

20.0
targets tuple[str, ...]

Tuple of linear layer name suffixes to wrap.

DEFAULT_LORA_TARGETS
use_dora bool

If True, use DoRALinear instead of LoRALinear.

False

Returns:

Type Description
list[str]

List of dot-separated paths to wrapped layers.
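The targeting logic can be sketched as a suffix match on the last component of each dot-separated path, with the exclusion list taking precedence. The target and exclusion tuples below are illustrative defaults, not the module's LORA_EXCLUDED_NAMES:

```python
def is_lora_target(path: str,
                   targets=("q_proj", "k_proj", "v_proj", "o_proj"),
                   excluded_names=("lm_head",)) -> bool:
    """Would this layer path be wrapped? Exclusions win over target matches."""
    name = path.rsplit(".", 1)[-1]   # final component of the dotted path
    if name in excluded_names:
        return False
    return name in targets
```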