[sglang] Apple Silicon MLX 환경에서 SwitchGLU MoE 블록의 SwiGLU 활성화를 Gate Gather-QMV로 융합하여 성능 최적화
PR 링크: sgl-project/sglang#26188 상태: Merged | 변경: +1054 / -0
들어가며
최근 sglang-project/sglang 레포지토리의 PR #26188은 Apple Silicon 환경에서 MLX 백엔드를 사용할 때 Mixture-of-Experts (MoE) 모델의 성능을 최적화하는 중요한 개선 사항을 포함하고 있습니다. 특히, SwitchGLU MoE 블록에서 발생하는 SwiGLU 활성화 함수를 Gate의 Gather-QMV 연산과 융합하여 커널 디스패치 오버헤드를 줄이고자 합니다. 이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 왜 이러한 최적화가 효과적인지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.
기존 SwitchGLU MoE 블록은 각 레이어에서 여러 개의 커널 디스패치를 필요로 했습니다. Apple Silicon의 Metal 디스패치는 상당한 오버헤드(약 1-4ms의 GPU 유휴 시간)를 발생시키므로, 디스패치 횟수를 줄이는 것이 디코드 지연 시간을 단축하는 데 핵심적입니다. 이전 PR (#24712)에서는 up_proj와 gate_proj를 하나의 MatMul 연산으로 융합하려 했으나, 이는 커널의 출력 차원을 두 배로 늘려 배치 크기 4 이상에서 타일링 및 점유율 문제를 야기하며 성능 저하를 초래했습니다. 이번 PR은 이러한 문제를 해결하기 위해, 프로젝션 커널을 분리된 상태로 유지하면서 silu(gate) x up 활성화 부분을 Gate MatMul 커널에 직접 융합하는 방식을 채택했습니다.
코드 변경 분석
이번 PR의 핵심 변경 사항은 python/sglang/srt/hardware_backend/mlx/moe/fused_swiglu.py 파일에 집중되어 있으며, 관련 환경 설정 및 모델 로딩 로직도 수정되었습니다.
1. 환경 변수 설정 (environ.py)
새로운 환경 변수 SGLANG_MLX_FUSE_SWIGLU가 추가되어 SwiGLU 활성화 융합 기능을 옵트인 방식으로 제어할 수 있게 되었습니다. 기본값은 False로 설정되어 있어, 기존 동작에 영향을 주지 않습니다.
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -471,6 +471,7 @@ class Envs:
# MPS (Apple Silicon)
SGLANG_USE_MLX = EnvBool(False)
SGLANG_MLX_USE_CUSTOM_ROPE = EnvBool(False)
+ SGLANG_MLX_FUSE_SWIGLU = EnvBool(False)
# NPU
SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT = EnvBool(False)
2. 모델 로딩 시 패치 적용 (model_runner.py)
모델 로딩 과정에서 SGLANG_MLX_FUSE_SWIGLU 환경 변수가 활성화된 경우, fused_swiglu.py에 정의된 패치 함수를 호출하여 SwitchGLU 블록의 __call__ 메소드를 수정합니다. 이는 모델 로드 시점에 동적으로 이루어지며, 융합된 커널을 적용할 수 있는 블록 수를 로깅합니다.
--- a/python/sglang/srt/hardware_backend/mlx/model_runner.py
+++ b/python/sglang/srt/hardware_backend/mlx/model_runner.py
@@ -27,6 +27,7 @@
from mlx_lm import load as mlx_lm_load
from mlx_lm.utils import quantize_model as mlx_lm_quantize_model
+from sglang.srt.environ import envs
from sglang.srt.hardware_backend.mlx.aot import (
MLX_AOT_KERNEL_REGISTRY,
MlxAOTKernelSet,
@@ -461,6 +462,21 @@ def _load_model(self):
load_time = time.time() - start_time
logger.info(f"MLX model loaded in {load_time:.2f}s")
+ # Optional: Path B fusion — keep up_proj/gate_proj weights separate
+ # (no matmul-kernel tile regression) but fuse the swiglu activation
+ # into the gate matmul via a custom Metal kernel. Activated by
+ # SGLANG_MLX_FUSE_SWIGLU=1. Mutually exclusive with FUSE_SWITCHGLU.
+ # See: python/sglang/srt/hardware_backend/mlx/moe/fused_swiglu.py
+ if envs.SGLANG_MLX_FUSE_SWIGLU.get():
+ from sglang.srt.hardware_backend.mlx.moe.fused_swiglu import (
+ patch_switch_glu_with_fused_swiglu,
+ )
+
+ n_patched = patch_switch_glu_with_fused_swiglu(self.model)
+ logger.info(
+ f"MLX SwiGLU activation fusion enabled: patched {n_patched} blocks"
+ )
+
def _attention_module_for_layer(self, layer_idx: int) -> Any:
attn = getattr(
self._cache_layout.layers[layer_idx],
3. 핵심 융합 로직 (fused_swiglu.py)
이 파일은 PR의 핵심으로, 다음과 같은 주요 구성 요소를 포함합니다:
fused_gate_qmv_silu_mul: 새로운 Metal 커널로, Gate의 Quantized Matrix-Vector Multiplication (QMV) 연산과 SwiGLU 활성화(silu(gate) * x_up)를 하나의 커널에서 처리합니다. 이는 기존의gate_projMatMul,silu활성화,up_projMatMul, 그리고x_up과의 곱셈을 대체합니다.patch_switch_glu_with_fused_swiglu:SwitchGLU모델의__call__메소드를 동적으로 패치하는 함수입니다. 특정 조건(can_fuse)을 만족하는 경우에만 융합된 커널을 사용하도록 SwitchGLU 클래스를 서브클래싱하여 적용합니다.can_fuse: 융합이 가능한지 여부를 결정하는 함수입니다. MLX의 4비트 양자화(bits=4,group_size=64) 및 특정 차원 제약 조건(K % 512 == 0,N % 8 == 0)을 만족해야 합니다. 또한, 스케일/바이어스의 dtype이 입력 dtype과 일치해야 하며, Gate 프로젝션에 학습 가능한 바이어스가 없어야 합니다. 이러한 조건을 만족하지 못하면 기존의 비융합 경로로 폴백합니다._KERNEL_SOURCE: 융합된 Metal 커널의 C++ 소스 코드입니다. 이 커널은 MLX의qmv_fast_impl을 기반으로 하며,silu(result) * x_up연산을 쓰기 에필로그(write epilogue)로 통합합니다.
# python/sglang/srt/hardware_backend/mlx/moe/fused_swiglu.py
# ... (중략) ...
_KERNEL_SOURCE = r"""
// Mirrors qmv_fast_impl<T, group_size=64, bits=4> from MLX's quantized.h
// with a silu(result) * x_up write epilogue.
//
// Inputs:
// x [M_tok, K] — pre-gather activations (T)
// w [E, N, K * 4 / 32] — packed 4-bit weights (uint32)
// s [E, N, K / GROUP_SIZE] — affine scales (T)
// b [E, N, K / GROUP_SIZE] — affine biases (T)
// idx [M_tok * TOPK] — expert per (token, topk) pair (uint32)
// x_up [M_tok * TOPK, N] — precomputed up output (T)
// Output:
// y [M_tok * TOPK, N] — silu(gate_qmv(x)) * x_up
constexpr int BITS = 4;
constexpr int GROUP_SIZE = 64;
// ... (커널 내부 로직) ...
// Write epilogue: simd-sum across lanes, then silu(gate) * x_up.
device T* y_p = y + uint64_t(mt) * uint64_t(N) + uint64_t(out_row);
const device T* x_up_p = x_up + uint64_t(mt) * uint64_t(N) + uint64_t(out_row);
for (int row = 0; row < RESULTS_PER_SIMDGROUP; row++) {
float gate_v = simd_sum(result[row]);
if (simd_lid == 0) {
// silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
float silu_v = gate_v / (1.0f + metal::precise::exp(-gate_v));
y_p[row] = T(silu_v * float(x_up_p[row]));
}
}
"""
# ... (중략) ...
def patch_switch_glu_with_fused_swiglu(model):
"""Patches SwitchGLU blocks in the model to use fused_gate_qmv_silu_mul.
Args:
model: The MLX model to patch.
Returns:
The number of blocks patched.
"""
n_patched = 0
for layer in model.model.layers:
if hasattr(layer.mlp, "switch_mlp"):
sw = layer.mlp.switch_mlp
if can_fuse(sw):
# Apply patch via one-off subclass
sw.__class__ = _PathBSubclass(sw.__class__)
n_patched += 1
return n_patched
def can_fuse(sw: SwitchGLU) -> bool:
# ... (융합 조건 검사 로직) ...
return True # or False
# ... (중략) ...
4. 테스트 (test_fused_swiglu.py)
새로운 테스트 파일이 추가되어, 융합된 커널의 수치적 정확성과 전체 SwitchGLU 순방향 연산의 동등성을 검증합니다. 다양한 배치 크기와 정렬/비정렬 데이터에 대해 테스트를 수행하며, Qwen1.5-MoE-A2.7B-4bit 모델을 사용하여 정확도 테스트를 진행합니다.
왜 이게 좋은가?
1. 디스패치 횟수 감소 및 오버헤드 절감
기존 SwitchGLU MoE 레이어는 일반적으로 3개의 주요 커널(up_proj, gate_proj, silu * x_up)을 필요로 했습니다. 이 PR은 gate_proj MatMul과 SwiGLU 활성화(silu(gate) * x_up)를 하나의 Metal 커널(fused_gate_qmv_silu_mul)로 융합했습니다. 결과적으로 MoE 레이어당 필요한 커널 디스패치 횟수가 3개에서 2개로 줄어듭니다. 이는 각 디스패치마다 발생하는 약 1-4ms의 GPU 유휴 시간 오버헤드를 줄여, 특히 MoE 레이어가 많은 모델에서 전체 디코드 지연 시간을 감소시킬 잠재력을 가집니다.
PR 설명에 따르면, Qwen3-30B-A3B 모델의 48개 레이어에서 약 96회의 디스패치가 감소할 수 있습니다 (배치 크기 1 기준).
2. 성능 수치 및 분석
PR에서는 다양한 시나리오에서 성능을 측정했습니다. 특히 Qwen3-30B-A3B-4bit 모델을 사용하여 배치 크기 1에서 측정한 결과, 융합 활성화(SGLANG_MLX_FUSE_SWIGLU=1)를 사용했을 때와 사용하지 않았을 때(SGLANG_MLX_FUSE_SWIGLU=0)의 토큰당 생성 속도(Decode tok/s)는 다음과 같습니다:
| Condition | Decode tok/s (median) | IQR |
|---|---|---|
off (SGLANG_MLX_FUSE_SWIGLU=0) |
73.99 | 73.03–74.21 |
on (SGLANG_MLX_FUSE_SWIGLU=1) |
73.70 | 72.92–74.05 |
결과적으로 융합 활성화 사용 시 -0.40%의 변화를 보였으며, 이는 측정 노이즈 범위 내에 있는 값입니다. 즉, 이 PR은 측정 가능한 엔드-투-엔드 속도 향상을 주장하지 않습니다. 대신, 커널 디스패치 수를 줄이는 '올바른 융합 기반(correct fusion substrate)'을 제공하는 데 중점을 둡니다. 이는 이전 시도(PR #24712)가 배치 크기 증가 시 성능 저하를 보였던 것과 대조적입니다. 이번 접근 방식은 MatMul 커널의 출력 차원을 늘리지 않아 타일링 및 점유율 문제를 피하면서 디스패치 수를 줄이는 데 성공했습니다.
3. 수치적 정확성 보장
융합된 커널은 기존 연산과 수치적으로 동등함을 보장하기 위해 광범위한 테스트가 수행되었습니다. Qwen1.5-MoE-A2.7B-4bit 모델을 사용한 테스트 결과, 커널 레벨 및 전체 SwitchGLU 순방향 연산에서 최대 상대 오차는 0.1% 미만으로 매우 작았습니다. 이는 실제 모델 동작에 부정적인 영향을 미치지 않음을 시사합니다.
| Check | bs | rel-max | bound |
|---|---|---|---|
kernel fused_gate_qmv_silu_mul |
1 | 0.063% | 2% |
kernel fused_gate_qmv_silu_mul |
4 | 0.091% | 2% |
| full SwitchGLU (unsorted) | 2 | 0.047% | 5% |
| full SwitchGLU (sorted) | 8 | 0.093% | 5% |
@alexnails가 Qwen3-30B-A3B-4bit 모델에서 검증한 결과 역시 커널 레벨 <1%, 전체 순방향 <0.6%의 오차를 보여, 높은 정확도를 유지함을 확인했습니다.
4. 일반적 교훈
- 커널 융합의 장점: 연산의 일부를 기존 커널에 통합하면 디스패치 오버헤드를 줄여 성능을 향상시킬 수 있습니다. 특히 GPU 컴퓨팅에서는 커널 발사(launch) 비용이 상당할 수 있으므로, 융합은 중요한 최적화 기법입니다.
- 성능 vs. 복잡성 트레이드오프: 이전 PR처럼 커널 자체를 더 크게 만드는 것은 특정 조건(예: 높은 배치 크기)에서 오히려 성능 저하를 유발할 수 있습니다. 활성화 함수나 작은 연산들을 기존 커널에 통합하는 방식은 커널의 기본 연산 특성을 유지하면서 오버헤드를 줄이는 더 안전한 접근 방식일 수 있습니다.
- 정확성 검증의 중요성: 복잡한 융합 연산은 반드시 엄격한 수치적 정확성 테스트를 거쳐야 합니다. PR에서는 다양한 조건에서 기존 연산과의 동등성을 검증하여 신뢰도를 높였습니다.
- 점진적 적용 및 옵트인: 새로운 최적화 기능은 기본적으로 비활성화하고 사용자가 명시적으로 활성화하도록 하는 것이 안전합니다. 이는 예상치 못한 부작용을 방지하고, 특정 하드웨어 또는 모델 구성에서만 최적화의 이점을 누릴 수 있도록 합니다.
리뷰 피드백 반영
리뷰 과정에서 몇 가지 중요한 개선이 이루어졌습니다:
_PathBSubclass캐싱:cls.__dict__멤버십 검사를 사용하여SwitchGLU서브클래스가 부모 클래스의 캐시 항목으로 다운캐스팅되는 것을 방지했습니다. 이는 캐시 동작의 정확성을 높입니다.can_fuse()의 dtype 게이트: 스케일/바이어스의 dtype 불일치 시 발생하는 오류를 방지하기 위해, 이러한 경우 융합 대신 기존 경로로 폴백하도록 수정되었습니다. 이는 더 넓은 범위의 모델 구성에서 안정적으로 작동하도록 합니다.- Gate 바이어스 처리: 기존의
QuantizedSwitchLinear가 추가하는 학습 가능한 Gate 바이어스는 새로운 융합 커널에서 처리할 슬롯이 없어 무시되었습니다.can_fuse()함수는 이제 Gate에 학습 가능한 바이어스가 있는 경우 융합을 거부하도록 업데이트되었으며, 관련 회귀 테스트도 추가되었습니다. - 정확성 재확인: 리뷰어(
yeahdongcn)가 제기한 sorted fallback 시 shape 불일치 버그(repro_pr26188_sorted_fallback.py재현 코드)를 수정했습니다 (adaff5fd36). 또한,can_fuse함수에activation체크를 추가하여 유사한 버그를 방지했습니다 (9fc1e09ecd).
결론
PR #26188은 Apple Silicon의 MLX 환경에서 SwitchGLU MoE 모델의 성능을 개선하기 위한 영리한 최적화 기법을 도입했습니다. 커널 디스패치 오버헤드를 줄이기 위해 SwiGLU 활성화를 Gate MatMul 커널에 융합하는 이 접근 방식은, 이전의 시도와 달리 MatMul 커널의 특성을 유지하면서도 효과적으로 디스패치 수를 줄입니다. 비록 측정 가능한 엔드-투-엔드 속도 향상은 미미했지만, 이는 더 많은 MoE 모델에서 잠재적인 성능 개선의 기반을 마련했으며, 엄격한 정확성 테스트를 통해 안정성이 검증되었습니다. 이 PR은 GPU 최적화에서 커널 융합의 중요성과 신중한 설계 및 검증의 필요성을 잘 보여주는 사례입니다.
References
- MLX Metal Kernels Documentation
- torch.compile (융합 및 컴파일 관련 개념 참고)
- SwitchGLU Layer (Hugging Face Transformers 라이브러리에서의 구현 참고)
- MLX Quantized Kernels (커널 구현 참고)
참고 자료
- https://github.com/sgl-project/sglang/pull/26188
- https://github.com/ml-explore/mlx/blob/main/docs/metal_kernels.md
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1011
- https://github.com/ml-explore/mlx/blob/main/mlx/include/mlx/backend/metal/kernels/quantized.h
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Triton AMD StreamK GEMM 커널의 Race Condition 해결: 동기화 로직 최적화 분석
- 현재글 : [sglang] Apple Silicon MLX 환경에서 SwitchGLU MoE 블록의 SwiGLU 활성화를 Gate Gather-QMV로 융합하여 성능 최적화
- 다음글 [sglang] SGLang PD-Disaggregation 최적화: Mori 백엔드에서의 증분 KV 전송 구현
댓글