가중치 포팅 가이드: Qwen2.5-3B → Bit-Axon¶
Bit-Axon의 아키텍처는 표준 트랜스포머와 상당히 다르지만, 초기 가중치는 Qwen2.5-3B에서 부트스트랩합니다. 이 가이드에서는 전체 포팅 파이프라인을 다룹니다. 각 파라미터 패밀리가 어떻게 매핑, 변환, 검증되는지 설명합니다.
왜 Qwen2.5-3B인가?¶
Qwen2.5-3B는 세 가지 이유로 검증된 기반이 됩니다.
- 사전 학습된 표현. 임베딩 행과 RMSNorm 가중치는 아키텍처 간 전이 가능한 분포 지식을 인코딩합니다. 이를 시작점으로 사용하는 것이 무작위 초기화보다 훨씬 낫습니다.
- 차원 정렬. Qwen2.5-3B의 은닉 차원(2048)은 Bit-Axon의 2560 바로 아래에 있습니다. 이로 인해 RMSNorm에 대한 패딩 전략이 거의 손실이 없으며, MLP→MoE 프로젝션이 소스 모델의 feedforward 용량 대부분을 유지할 수 있습니다.
- MLP 호환성. Qwen2.5-3B의 밀집 SwiGLU MLP는 Bit-Axon의 공유 expert MoE에 깔끔하게 매핑됩니다. 게이트 구조(gate/up/down 프로젝션)가 보존되며, expert 간에 잘리고 복제됩니다.
아키텍처 불일치 한눈에 보기¶
| 측면 | Qwen2.5-3B | Bit-Axon |
|---|---|---|
| 어휘 크기 | 151,936 토큰 | 32,000 토큰 |
| 은닉 차원 | 2,048 | 2,560 |
| MLP 중간 차원 | 11,008 | 4,096 (expert당) |
| FFN 구조 | 밀집 SwiGLU | 8-expert 공유 MoE |
| 정규화 레이어 | RMSNorm (2048) | RMSNorm (2560) |
| 어텐션 | 전체 36 레이어 | SWA만, 레이어 9-16 |
| SSM 레이어 | 없음 | 레이어 1-8, 17-24 |
어휘 매핑¶
첫 번째 단계는 Qwen의 151K tokenizer를 Bit-Axon의 32K로 축소하는 것입니다. 두 가지 전략이 있습니다.
First-N (기본값)¶
BPE 병합 순서대로 처음 32,000개 토큰을 가져옵니다. BPE 병합이 빈도 순서에 근사하므로 가장 흔한 토큰이 유지됩니다.
from bit_axon.porting.vocab_map import build_vocab_mapping
# 기본값: BPE 순서대로 처음 32K 토큰
vocab_mapping = build_vocab_mapping(
tokenizer_name="Qwen/Qwen2.5-3B",
target_size=32000,
)
# vocab_mapping: {0: 0, 1: 1, 2: 2, ..., 31999: 31999}
빈도 기반 선택¶
대표적인 코퍼스 텍스트를 전달하여 가장 빈도가 높은 32K 토큰을 선택합니다.
vocab_mapping = build_vocab_mapping(
tokenizer_name="Qwen/Qwen2.5-3B",
target_size=32000,
corpus_text=open("corpus.txt").read(),
)
# vocab_mapping의 키는 Qwen ID, 값은 새 Bit-Axon ID
매핑은 간단한 dict입니다: {old_qwen_id: new_bitaxon_id}. 다운스트림에서 임베딩 추출 단계가 이를 읽어 행을 재정렬합니다.
축소된 tokenizer 로드¶
매핑을 구축한 후 32K 선택된 토큰만 아는 tokenizer를 만들 수 있습니다.
from bit_axon.porting.vocab_map import load_truncated_tokenizer
tokenizer = load_truncated_tokenizer("Qwen/Qwen2.5-3B", vocab_mapping)
encoded = tokenizer.encode("Hello, world!")
print(encoded.ids) # 모든 ID가 [0, 32000) 범위
가중치 매핑¶
weight_map.py는 BitAxonModel의 모든 파라미터를 5가지 변환 카테고리 중 하나로 분류합니다. 기본 설정(24 레이어, 8 expert)으로 517개 파라미터 키를 생성합니다.
from bit_axon.config import BitAxonConfig
from bit_axon.porting.weight_map import build_key_mappings
config = BitAxonConfig()
mappings = build_key_mappings(config)
# 분류 결과 확인
from collections import Counter
counts = Counter(m.transform for m in mappings)
print(counts)
# Counter({'default': 401, 'pad_2048_2560': 72, 'moe_project': 24,
# 'copy_perturb': 18, 'vocab_extract': 2})
각 매핑은 세 개 필드를 가진 KeyMapping입니다.
target_key: BitAxonModel에서의 파라미터 이름source_key: Qwen2.5-3B에서의 해당 파라미터 이름 (또는 동등한 것이 없으면None)transform: 소스 가중치를 대상 가중치로 변환하는 방법
변환 유형¶
vocab_extract: 임베딩 행 재정렬¶
embed_tokens.weight와 lm_head.weight에 적용됩니다. Bit-Axon은 가중치 결합(weight tying)을 사용하므로 두 항목이 동일한 행렬을 가리킵니다. 이 변환은 Qwen의 (151936, 2048) 임베딩 테이블에서 어휘 매핑으로 지정된 32K 행을 선택하고 (32000, 2048) 행렬로 재정렬합니다.
from bit_axon.porting.mapper import extract_embeddings
embeddings = extract_embeddings(
qwen_weights,
vocab_mapping,
target_vocab_size=32000,
source_hidden_dim=2048,
)
# 형태: (32000, 2048)
pad_2048_2560: 제로 패딩 RMSNorm¶
모든 input_norm.weight, post_attention_norm.weight, post_ssm_norm.weight 파라미터에 적용됩니다. RMSNorm은 1.0으로 초기화되므로, 2048에서 2560으로 1.0으로 패딩하는 것은 거의 손실이 없습니다.
from bit_axon.porting.mapper import pad_rms_norm
padded = pad_rms_norm(qwen_weights["model.layers.5.input_layernorm.weight"], target_dim=2560)
# 형태: (2560,) — 처음 2048개 값은 Qwen에서, 나머지 512개는 1.0
moe_project: 구조화된 잘라내기 + 제로 패딩¶
MoE 레이어(레이어 8-23)의 공유 expert의 gate/up/down 프로젝션에 적용됩니다. Qwen의 밀집 MLP는 중간 차원이 11,008이고, Bit-Axon의 expert는 4,096을 사용합니다. 변환은 처음 4,096개의 행/열을 잘라내고 은닉 차원을 2,048에서 2,560으로 제로 패딩합니다.
Qwen gate_proj (11008, 2048) → 열 0-2047 자르기, (4096, 2560)으로 패딩
Qwen up_proj (11008, 2048) → 열 0-2047 자르기, (4096, 2560)으로 패딩
Qwen down_proj (2048, 11008) → 행 0-2047, 열 0-4095 자르기, (2560, 4096)으로 패딩
from bit_axon.porting.mapper import project_mlp_to_shared_expert
gate, up, down = project_mlp_to_shared_expert(
qwen_gate=qwen_weights["model.layers.10.mlp.gate_proj.weight"],
qwen_up=qwen_weights["model.layers.10.mlp.up_proj.weight"],
qwen_down=qwen_weights["model.layers.10.mlp.down_proj.weight"],
target_intermediate=4096,
target_hidden=2560,
source_hidden=2048,
)
copy_perturb: 라우팅 expert를 위해 공유 expert 복제¶
MoE 레이어의 switch_mlp 라우팅 expert에 적용됩니다. Expert 0은 공유 expert의 정확한 복사본입니다. Expert 1~7은 공유 expert의 가중치에 표준편차 0.02의 가우시안 노이즈를 더한 값을 갖습니다.
from bit_axon.porting.mapper import init_routed_experts
routed_gate, routed_up, routed_down = init_routed_experts(
shared_gate=gate,
shared_up=up,
shared_down=down,
num_experts=8,
perturbation_std=0.02,
)
# routed_gate 형태: (8, 4096, 2560)
# routed_up 형태: (8, 4096, 2560)
# routed_down 형태: (8, 2560, 4096)
작은 섭동(perturbation)은 각 expert가 파인튜닝 전 고유한 시작점을 갖도록 하면서, 공유 expert의 검증된 표현에 충분히 가까워 빠르게 학습할 수 있도록 합니다.
default: 무작위 초기화 유지¶
Qwen에 동등한 것이 없는 파라미터는 기본 초기화 상태를 유지합니다. 여기에는 다음이 포함됩니다.
- SSM 파라미터 (
ssm.*): A, B, C, D 행렬 및 합성곱 커널. Bit-Axon의 Mamba 스타일 SSM은 Qwen에 대응하는 것이 없습니다. - 어텐션 파라미터 (
attention.*): 슬라이딩 윈도우 어텐션 레이어의 Q, K, V, O 프로젝션. - 라우터 파라미터 (
moe.gate.weight): top-2 라우팅 게이트는 밀집 대응이 없습니다. - 차원 브릿지 (
input_proj.weight,output_proj.weight): 2048과 2560 차원 사이를 연결하는 선형 레이어.
전체 파이프라인¶
CLI¶
CLI는 엔드투엔드로 모든 것을 처리합니다. Qwen 다운로드, 어휘 매핑 구축, 모든 변환 실행, 결과 저장.
전체 모델 다운로드 없이 빠르게 테스트하려면 소형 설정으로 목업 가중치를 사용하세요.
소형 설정은 hidden_dim=256, num_layers=4, d_source_model=128, vocab_size=1024를 사용합니다. 목업 Qwen 가중치가 즉시 생성되므로 다운로드가 필요 없습니다. 전체 포팅에 앞서 파이프라인이 오류 없이 실행되는지 확인하는 데 유용합니다.
Python API¶
더 많은 제어가 필요하면 파이프라인 함수를 직접 사용하세요.
import mlx.core as mx
from bit_axon.config import BitAxonConfig
from bit_axon.porting.vocab_map import build_vocab_mapping
from bit_axon.porting.pipeline import initialize_from_qwen_weights, save_ported_model
# 1. Qwen 가중치 로드
weight_files = sorted(glob.glob("/path/to/qwen/*.safetensors"))
qwen_weights = {}
for f in weight_files:
qwen_weights.update(mx.load(f))
# 2. 어휘 매핑 구축
vocab_mapping = build_vocab_mapping(
tokenizer_name="Qwen/Qwen2.5-3B",
target_size=32000,
)
# 3. 파이프라인 실행
config = BitAxonConfig()
model, vocab_mapping = initialize_from_qwen_weights(
qwen_weights,
vocab_mapping=vocab_mapping,
config=config,
)
# 4. 저장
save_ported_model(model, "./output/model.safetensors", vocab_mapping)
어휘 매핑을 건너뛰고 기본 항등 매핑을 사용할 수도 있습니다.
검증¶
포팅 후 문제가 없는지 확인하기 위해 건전성 검사를 실행합니다.
가중치 통계¶
visualization.py 모듈은 분포 통계를 계산하고 이상을 감지합니다.
from mlx.utils import tree_flatten
from bit_axon.porting.visualization import compute_weight_stats, detect_anomalies, format_stats_table
params = dict(tree_flatten(model.parameters()))
stats = compute_weight_stats(params)
# 가장 이상한 가중치 테이블 출력
print(format_stats_table(stats))
# 문제 확인
warnings = detect_anomalies(stats)
for w in warnings:
print(w)
이상 감지기는 네 가지 조건을 감지합니다.
| 조건 | 임계값 | 가능한 원인 |
|---|---|---|
| 모두 0 | max == 0 and min == 0 | 변환 건너뛰거나 소스 키 누락 |
| NaN 값 | mean 또는 std가 NaN | 프로젝션 중 형태 불일치 |
| 높은 이상치 비율 | >10%의 값이 3σ 초과 | 잘못된 섭동 또는 패딩 |
| 극단적 희소성 | >99%가 0에 근접 | 차원 불일치, 빈 값으로 잘림 |
빠른 형태 확인¶
포팅 후 모든 파라미터가 예상 형태를 가지는지 확인합니다.
from bit_axon.porting.weight_map import build_key_mappings
mappings = build_key_mappings(config)
params = dict(tree_flatten(model.parameters()))
for m in mappings:
if m.target_key not in params:
print(f"MISSING: {m.target_key}")
elif m.transform == "vocab_extract":
assert params[m.target_key].shape == (config.vocab_size, config.d_source_model)
elif m.transform == "pad_2048_2560":
assert params[m.target_key].shape == (config.hidden_dim,)
elif m.transform == "moe_project":
assert params[m.target_key].shape[0] in (config.moe_intermediate_dim, config.hidden_dim)
elif m.transform == "copy_perturb":
assert params[m.target_key].shape[0] == config.moe_num_experts
print("All shapes validated.")
포팅 후 다음 단계¶
포팅된 모델은 완성된 모델이 아니라 시작점입니다. SSM과 어텐션 파라미터는 무작위 초기화에서 시작합니다. 라우팅 expert는 작은 노이즈가 추가된 공유 expert의 복사본입니다. 다음을 위해 파인튜닝(bit-axon train을 통한 QLoRA)이 필요합니다.
- SSM 레이어가 순차적 컨텍스트를 흡수하도록 학습
- SWA 레이어의 어텐션 헤드 보정
- 라우팅 expert를 차별화하여 전문화
- 출력 헤드를 축소된 어휘에 정렬