[Axolotl] Qwen 3.5 모델 Liger 커널 지원 및 fused RMSNorm+Gated 커널 추가
PR 링크: axolotl-ai-cloud/axolotl#3531 상태: Merged | 변경: +970 / -0
들어가며
Qwen 3.5는 기존 transformer와 다른 두 가지 특징이 있습니다:
- RMSNorm에 zero-init weight + offset 1.0 패턴 사용 (Gemma와 유사)
- Linear attention 레이어에 gated RMSNorm (
RMSNormGated) 사용
이 PR은 이 두 특성에 맞는 Liger 커널 통합과 Fused Linear Cross Entropy(FLCE) 지원을 추가합니다.
핵심 코드 분석
1. Qwen 3.5 RMSNorm 호환 wrapper
# integrations/liger/models/qwen3_5.py
if rms_norm:
class LigerRMSNormForQwen3_5(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim, eps=eps,
offset=1.0, # output * (1.0 + weight) 패턴
casting_mode="gemma", # Gemma 스타일 캐스팅
init_fn="zeros", # zero-init weight
in_place=False,
)
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
Qwen 3.5의 RMSNorm은 output * (1.0 + weight) 패턴을 사용합니다. weight가 0으로 초기화되어 처음에는 항등 변환이고, 학습 중 미세 조정됩니다. offset=1.0과 init_fn="zeros"로 이 동작을 정확히 재현합니다.
2. Fused RMSNorm+Gated 커널 통합
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
Linear attention 레이어의 RMSNormGated는 RMSNorm과 SiLU gate를 하나의 연산으로 fuse합니다. 별도 커널 없이 두 연산을 순차 실행하면 메모리 접근이 2배 발생하는데, fused 커널은 한 번의 메모리 접근으로 두 연산을 처리합니다.
3. Liger 설정 확장
# integrations/liger/args.py
liger_rms_norm_gated: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
)
},
)
사용자가 YAML 설정에서 liger_rms_norm_gated: true로 간단하게 활성화할 수 있습니다.
왜 이게 좋은가
- 메모리 절약: FLCE는 logits를 메모리에 실체화하지 않아 vocabulary 크기에 비례한 메모리를 절약합니다.
- 커널 fusion: RMSNorm + SiLU gate를 하나의 Triton 커널로 처리하여 memory bandwidth 병목을 줄입니다.
- 올바른 초기화: Gemma 스타일의 offset RMSNorm을 정확하게 구현하여 학습 안정성을 보장합니다.
정리
Qwen 3.5 계열 모델의 고유한 아키텍처(offset RMSNorm, gated RMSNorm)에 맞춘 Liger 커널 통합입니다. 970줄 추가의 대부분은 Qwen 3.5와 Qwen 3.5 MoE 각각의 lce_forward와 apply_liger_kernel_to_* 함수입니다.
참고 자료
이 포스트는 AI가 작성하였으며, 사실과 다를 수 있습니다. 정확한 정보는 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Open WebUI] 메모리 항목 삭제 시 확인 대화상자 추가
- 현재글 : [Axolotl] Qwen 3.5 모델 Liger 커널 지원 및 fused RMSNorm+Gated 커널 추가
- 다음글 [Axolotl] LoRA 커널에 bias, dropout, DoRA, embedding 지원 추가
댓글