본문으로 건너뛰기

[sglang] Apple Silicon MLX 환경에서 SwitchGLU MoE 블록의 SwiGLU 활성화를 Gate Gather-QMV로 융합하여 성능 최적화

PR 링크: sgl-project/sglang#26188 상태: Merged | 변경: +1054 / -0

들어가며

최근 sglang-project/sglang 레포지토리의 PR #26188은 Apple Silicon 환경에서 MLX 백엔드를 사용할 때 Mixture-of-Experts (MoE) 모델의 성능을 최적화하는 중요한 개선 사항을 포함하고 있습니다. 특히, SwitchGLU MoE 블록에서 발생하는 SwiGLU 활성화 함수를 Gate의 Gather-QMV 연산과 융합하여 커널 디스패치 오버헤드를 줄이고자 합니다. 이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 왜 이러한 최적화가 효과적인지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.

기존 SwitchGLU MoE 블록은 각 레이어에서 여러 개의 커널 디스패치를 필요로 했습니다. Apple Silicon의 Metal 디스패치는 상당한 오버헤드(약 1-4ms의 GPU 유휴 시간)를 발생시키므로, 디스패치 횟수를 줄이는 것이 디코드 지연 시간을 단축하는 데 핵심적입니다. 이전 PR (#24712)에서는 up_projgate_proj를 하나의 MatMul 연산으로 융합하려 했으나, 이는 커널의 출력 차원을 두 배로 늘려 배치 크기 4 이상에서 타일링 및 점유율 문제를 야기하며 성능 저하를 초래했습니다. 이번 PR은 이러한 문제를 해결하기 위해, 프로젝션 커널을 분리된 상태로 유지하면서 silu(gate) x up 활성화 부분을 Gate MatMul 커널에 직접 융합하는 방식을 채택했습니다.

코드 변경 분석

이번 PR의 핵심 변경 사항은 python/sglang/srt/hardware_backend/mlx/moe/fused_swiglu.py 파일에 집중되어 있으며, 관련 환경 설정 및 모델 로딩 로직도 수정되었습니다.

1. 환경 변수 설정 (environ.py)

새로운 환경 변수 SGLANG_MLX_FUSE_SWIGLU가 추가되어 SwiGLU 활성화 융합 기능을 옵트인 방식으로 제어할 수 있게 되었습니다. 기본값은 False로 설정되어 있어, 기존 동작에 영향을 주지 않습니다.

--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -471,6 +471,7 @@ class Envs:
     # MPS (Apple Silicon)
     SGLANG_USE_MLX = EnvBool(False)
     SGLANG_MLX_USE_CUSTOM_ROPE = EnvBool(False)
+    SGLANG_MLX_FUSE_SWIGLU = EnvBool(False)
 
     # NPU
     SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT = EnvBool(False)

2. 모델 로딩 시 패치 적용 (model_runner.py)

모델 로딩 과정에서 SGLANG_MLX_FUSE_SWIGLU 환경 변수가 활성화된 경우, fused_swiglu.py에 정의된 패치 함수를 호출하여 SwitchGLU 블록의 __call__ 메소드를 수정합니다. 이는 모델 로드 시점에 동적으로 이루어지며, 융합된 커널을 적용할 수 있는 블록 수를 로깅합니다.

--- a/python/sglang/srt/hardware_backend/mlx/model_runner.py
+++ b/python/sglang/srt/hardware_backend/mlx/model_runner.py
@@ -27,6 +27,7 @@
 from mlx_lm import load as mlx_lm_load
 from mlx_lm.utils import quantize_model as mlx_lm_quantize_model
 
+from sglang.srt.environ import envs
 from sglang.srt.hardware_backend.mlx.aot import (
     MLX_AOT_KERNEL_REGISTRY,
     MlxAOTKernelSet,
@@ -461,6 +462,21 @@ def _load_model(self):
         load_time = time.time() - start_time
         logger.info(f"MLX model loaded in {load_time:.2f}s")
 
+        # Optional: Path B fusion — keep up_proj/gate_proj weights separate
+        # (no matmul-kernel tile regression) but fuse the swiglu activation
+        # into the gate matmul via a custom Metal kernel. Activated by
+        # SGLANG_MLX_FUSE_SWIGLU=1. Mutually exclusive with FUSE_SWITCHGLU.
+        # See: python/sglang/srt/hardware_backend/mlx/moe/fused_swiglu.py
+        if envs.SGLANG_MLX_FUSE_SWIGLU.get():
+            from sglang.srt.hardware_backend.mlx.moe.fused_swiglu import (
+                patch_switch_glu_with_fused_swiglu,
+            )
+
+            n_patched = patch_switch_glu_with_fused_swiglu(self.model)
+            logger.info(
+                f"MLX SwiGLU activation fusion enabled: patched {n_patched} blocks"
+            )
+
     def _attention_module_for_layer(self, layer_idx: int) -> Any:
         attn = getattr(
             self._cache_layout.layers[layer_idx],

3. 핵심 융합 로직 (fused_swiglu.py)

이 파일은 PR의 핵심으로, 다음과 같은 주요 구성 요소를 포함합니다:

  • fused_gate_qmv_silu_mul: 새로운 Metal 커널로, Gate의 Quantized Matrix-Vector Multiplication (QMV) 연산과 SwiGLU 활성화(silu(gate) * x_up)를 하나의 커널에서 처리합니다. 이는 기존의 gate_proj MatMul, silu 활성화, up_proj MatMul, 그리고 x_up과의 곱셈을 대체합니다.
  • patch_switch_glu_with_fused_swiglu: SwitchGLU 모델의 __call__ 메소드를 동적으로 패치하는 함수입니다. 특정 조건(can_fuse)을 만족하는 경우에만 융합된 커널을 사용하도록 SwitchGLU 클래스를 서브클래싱하여 적용합니다.
  • can_fuse: 융합이 가능한지 여부를 결정하는 함수입니다. MLX의 4비트 양자화(bits=4, group_size=64) 및 특정 차원 제약 조건(K % 512 == 0, N % 8 == 0)을 만족해야 합니다. 또한, 스케일/바이어스의 dtype이 입력 dtype과 일치해야 하며, Gate 프로젝션에 학습 가능한 바이어스가 없어야 합니다. 이러한 조건을 만족하지 못하면 기존의 비융합 경로로 폴백합니다.
  • _KERNEL_SOURCE: 융합된 Metal 커널의 C++ 소스 코드입니다. 이 커널은 MLX의 qmv_fast_impl을 기반으로 하며, silu(result) * x_up 연산을 쓰기 에필로그(write epilogue)로 통합합니다.
# python/sglang/srt/hardware_backend/mlx/moe/fused_swiglu.py

# ... (중략) ...

_KERNEL_SOURCE = r"""
    // Mirrors qmv_fast_impl<T, group_size=64, bits=4> from MLX's quantized.h
    // with a silu(result) * x_up write epilogue.
    //
    // Inputs:
    //   x       [M_tok, K]                — pre-gather activations (T)
    //   w       [E, N, K * 4 / 32]        — packed 4-bit weights (uint32)
    //   s       [E, N, K / GROUP_SIZE]    — affine scales (T)
    //   b       [E, N, K / GROUP_SIZE]    — affine biases (T)
    //   idx     [M_tok * TOPK]            — expert per (token, topk) pair (uint32)
    //   x_up    [M_tok * TOPK, N]         — precomputed up output (T)
    // Output:
    //   y       [M_tok * TOPK, N]         — silu(gate_qmv(x)) * x_up

    constexpr int BITS = 4;
    constexpr int GROUP_SIZE = 64;
    // ... (커널 내부 로직) ...

    // Write epilogue: simd-sum across lanes, then silu(gate) * x_up.
    device T* y_p = y + uint64_t(mt) * uint64_t(N) + uint64_t(out_row);
    const device T* x_up_p = x_up + uint64_t(mt) * uint64_t(N) + uint64_t(out_row);

    for (int row = 0; row < RESULTS_PER_SIMDGROUP; row++) {
        float gate_v = simd_sum(result[row]);
        if (simd_lid == 0) {
            // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
            float silu_v = gate_v / (1.0f + metal::precise::exp(-gate_v));
            y_p[row] = T(silu_v * float(x_up_p[row]));
        }
    }
"""

# ... (중략) ...

def patch_switch_glu_with_fused_swiglu(model):
    """Patches SwitchGLU blocks in the model to use fused_gate_qmv_silu_mul.

    Args:
        model: The MLX model to patch.

    Returns:
        The number of blocks patched.
    """
    n_patched = 0
    for layer in model.model.layers:
        if hasattr(layer.mlp, "switch_mlp"):
            sw = layer.mlp.switch_mlp
            if can_fuse(sw):
                # Apply patch via one-off subclass
                sw.__class__ = _PathBSubclass(sw.__class__)
                n_patched += 1
    return n_patched

def can_fuse(sw: SwitchGLU) -> bool:
    # ... (융합 조건 검사 로직) ...
    return True # or False

# ... (중략) ...

4. 테스트 (test_fused_swiglu.py)

새로운 테스트 파일이 추가되어, 융합된 커널의 수치적 정확성과 전체 SwitchGLU 순방향 연산의 동등성을 검증합니다. 다양한 배치 크기와 정렬/비정렬 데이터에 대해 테스트를 수행하며, Qwen1.5-MoE-A2.7B-4bit 모델을 사용하여 정확도 테스트를 진행합니다.

왜 이게 좋은가?

1. 디스패치 횟수 감소 및 오버헤드 절감

기존 SwitchGLU MoE 레이어는 일반적으로 3개의 주요 커널(up_proj, gate_proj, silu * x_up)을 필요로 했습니다. 이 PR은 gate_proj MatMul과 SwiGLU 활성화(silu(gate) * x_up)를 하나의 Metal 커널(fused_gate_qmv_silu_mul)로 융합했습니다. 결과적으로 MoE 레이어당 필요한 커널 디스패치 횟수가 3개에서 2개로 줄어듭니다. 이는 각 디스패치마다 발생하는 약 1-4ms의 GPU 유휴 시간 오버헤드를 줄여, 특히 MoE 레이어가 많은 모델에서 전체 디코드 지연 시간을 감소시킬 잠재력을 가집니다.

PR 설명에 따르면, Qwen3-30B-A3B 모델의 48개 레이어에서 약 96회의 디스패치가 감소할 수 있습니다 (배치 크기 1 기준).

2. 성능 수치 및 분석

PR에서는 다양한 시나리오에서 성능을 측정했습니다. 특히 Qwen3-30B-A3B-4bit 모델을 사용하여 배치 크기 1에서 측정한 결과, 융합 활성화(SGLANG_MLX_FUSE_SWIGLU=1)를 사용했을 때와 사용하지 않았을 때(SGLANG_MLX_FUSE_SWIGLU=0)의 토큰당 생성 속도(Decode tok/s)는 다음과 같습니다:

Condition Decode tok/s (median) IQR
off (SGLANG_MLX_FUSE_SWIGLU=0) 73.99 73.03–74.21
on (SGLANG_MLX_FUSE_SWIGLU=1) 73.70 72.92–74.05

결과적으로 융합 활성화 사용 시 -0.40%의 변화를 보였으며, 이는 측정 노이즈 범위 내에 있는 값입니다. 즉, 이 PR은 측정 가능한 엔드-투-엔드 속도 향상을 주장하지 않습니다. 대신, 커널 디스패치 수를 줄이는 '올바른 융합 기반(correct fusion substrate)'을 제공하는 데 중점을 둡니다. 이는 이전 시도(PR #24712)가 배치 크기 증가 시 성능 저하를 보였던 것과 대조적입니다. 이번 접근 방식은 MatMul 커널의 출력 차원을 늘리지 않아 타일링 및 점유율 문제를 피하면서 디스패치 수를 줄이는 데 성공했습니다.

3. 수치적 정확성 보장

융합된 커널은 기존 연산과 수치적으로 동등함을 보장하기 위해 광범위한 테스트가 수행되었습니다. Qwen1.5-MoE-A2.7B-4bit 모델을 사용한 테스트 결과, 커널 레벨 및 전체 SwitchGLU 순방향 연산에서 최대 상대 오차는 0.1% 미만으로 매우 작았습니다. 이는 실제 모델 동작에 부정적인 영향을 미치지 않음을 시사합니다.

Check bs rel-max bound
kernel fused_gate_qmv_silu_mul 1 0.063% 2%
kernel fused_gate_qmv_silu_mul 4 0.091% 2%
full SwitchGLU (unsorted) 2 0.047% 5%
full SwitchGLU (sorted) 8 0.093% 5%

@alexnails가 Qwen3-30B-A3B-4bit 모델에서 검증한 결과 역시 커널 레벨 <1%, 전체 순방향 <0.6%의 오차를 보여, 높은 정확도를 유지함을 확인했습니다.

4. 일반적 교훈

  • 커널 융합의 장점: 연산의 일부를 기존 커널에 통합하면 디스패치 오버헤드를 줄여 성능을 향상시킬 수 있습니다. 특히 GPU 컴퓨팅에서는 커널 발사(launch) 비용이 상당할 수 있으므로, 융합은 중요한 최적화 기법입니다.
  • 성능 vs. 복잡성 트레이드오프: 이전 PR처럼 커널 자체를 더 크게 만드는 것은 특정 조건(예: 높은 배치 크기)에서 오히려 성능 저하를 유발할 수 있습니다. 활성화 함수나 작은 연산들을 기존 커널에 통합하는 방식은 커널의 기본 연산 특성을 유지하면서 오버헤드를 줄이는 더 안전한 접근 방식일 수 있습니다.
  • 정확성 검증의 중요성: 복잡한 융합 연산은 반드시 엄격한 수치적 정확성 테스트를 거쳐야 합니다. PR에서는 다양한 조건에서 기존 연산과의 동등성을 검증하여 신뢰도를 높였습니다.
  • 점진적 적용 및 옵트인: 새로운 최적화 기능은 기본적으로 비활성화하고 사용자가 명시적으로 활성화하도록 하는 것이 안전합니다. 이는 예상치 못한 부작용을 방지하고, 특정 하드웨어 또는 모델 구성에서만 최적화의 이점을 누릴 수 있도록 합니다.

리뷰 피드백 반영

리뷰 과정에서 몇 가지 중요한 개선이 이루어졌습니다:

  • _PathBSubclass 캐싱: cls.__dict__ 멤버십 검사를 사용하여 SwitchGLU 서브클래스가 부모 클래스의 캐시 항목으로 다운캐스팅되는 것을 방지했습니다. 이는 캐시 동작의 정확성을 높입니다.
  • can_fuse()의 dtype 게이트: 스케일/바이어스의 dtype 불일치 시 발생하는 오류를 방지하기 위해, 이러한 경우 융합 대신 기존 경로로 폴백하도록 수정되었습니다. 이는 더 넓은 범위의 모델 구성에서 안정적으로 작동하도록 합니다.
  • Gate 바이어스 처리: 기존의 QuantizedSwitchLinear가 추가하는 학습 가능한 Gate 바이어스는 새로운 융합 커널에서 처리할 슬롯이 없어 무시되었습니다. can_fuse() 함수는 이제 Gate에 학습 가능한 바이어스가 있는 경우 융합을 거부하도록 업데이트되었으며, 관련 회귀 테스트도 추가되었습니다.
  • 정확성 재확인: 리뷰어(yeahdongcn)가 제기한 sorted fallback 시 shape 불일치 버그(repro_pr26188_sorted_fallback.py 재현 코드)를 수정했습니다 (adaff5fd36). 또한, can_fuse 함수에 activation 체크를 추가하여 유사한 버그를 방지했습니다 (9fc1e09ecd).

결론

PR #26188은 Apple Silicon의 MLX 환경에서 SwitchGLU MoE 모델의 성능을 개선하기 위한 영리한 최적화 기법을 도입했습니다. 커널 디스패치 오버헤드를 줄이기 위해 SwiGLU 활성화를 Gate MatMul 커널에 융합하는 이 접근 방식은, 이전의 시도와 달리 MatMul 커널의 특성을 유지하면서도 효과적으로 디스패치 수를 줄입니다. 비록 측정 가능한 엔드-투-엔드 속도 향상은 미미했지만, 이는 더 많은 MoE 모델에서 잠재적인 성능 개선의 기반을 마련했으며, 엄격한 정확성 테스트를 통해 안정성이 검증되었습니다. 이 PR은 GPU 최적화에서 커널 융합의 중요성과 신중한 설계 및 검증의 필요성을 잘 보여주는 사례입니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글