[sglang] AMD GPU 최적화: Triton 커널 퓨전을 통한 Qwen2 MoE 공유 전문가 게이팅 성능 향상
PR 링크: sgl-project/sglang#27636 상태: Merged | 변경: +189 / -1
들어가며
최근 sglang-project/sglang 레포지토리의 Pull Request(PR) #27636은 AMD GPU 환경에서 Qwen2 MoE 모델의 추론 성능을 최적화하는 중요한 변경사항을 포함하고 있습니다. 특히, Qwen2MoeSparseMoeBlock._forward_shared_experts 내에서 발생하는 공유 전문가 게이팅(shared expert gating) 연산의 비효율성을 개선하는 데 초점을 맞추고 있습니다. 기존에는 두 개의 별도 커널로 처리되던 F.sigmoid와 요소별 곱셈(elementwise multiplication) 연산을 하나의 Triton 커널로 융합함으로써, 커널 실행 오버헤드를 줄이고 전반적인 추론 속도를 향상시키는 것을 목표로 합니다.
이 글에서는 해당 PR의 코드 변경사항을 상세히 분석하고, 왜 이러한 최적화가 성능 향상에 기여하는지, 그리고 이를 통해 얻을 수 있는 일반적인 교훈은 무엇인지 기술 블로그 형식으로 풀어보고자 합니다.
코드 변경 분석
이번 PR의 핵심은 Triton 커널을 사용하여 기존의 두 개 연산을 하나로 융합하는 것입니다. 변경 사항은 크게 세 부분으로 나눌 수 있습니다.
- 새로운 Triton 커널 구현 (
python/sglang/jit_kernel/triton/sigmoid_gate_mul.py) - 모델 코드 수정 (
python/sglang/srt/models/qwen2_moe.py) - 단위 테스트 추가 (
test/registered/kernels/test_sigmoid_gate_mul.py)
1. 새로운 Triton 커널 구현: sigmoid_gate_mul.py
이 파일은 두 가지 종류의 융합된 Triton 커널을 정의합니다.
sigmoid_gate_mul(x, gate): 입력 텐서x와gate의 형태가 동일할 때,x * torch.sigmoid(gate)연산을 단일 커널로 처리합니다. 이는 이전 PR(#27630)에서 구현된 동일 형태(same-shape) 연산에 대한 커널입니다.sigmoid_gate_mul_broadcast(x, gate): 공유 전문가 게이팅에서 발생하는 일반적인 경우로,gate의 형태가(N, 1)이고x의 형태가(N, D)일 때,x * torch.sigmoid(gate)연산을 브로드캐스팅(broadcasting)을 지원하며 단일 커널로 처리합니다. AMD GPU(ROCm/HIP) 환경에서의 성능 최적화를 위해 도입되었습니다.
주요 코드 변경:
# python/sglang/jit_kernel/triton/sigmoid_gate_mul.py
# ... (import statements)
@triton.jit
def _sigmoid_gate_mul_kernel(
out_ptr,
gate_ptr,
x_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_elements
# Load gate and x, promote to float32 for sigmoid
g = tl.load(gate_ptr + offs, mask=mask).to(tl.float32)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
# Compute sigmoid and multiply
out = x * tl.sigmoid(g)
# Store result in original dtype
tl.store(out_ptr + offs, out.to(gate_ptr.dtype.element_ty), mask=mask)
def sigmoid_gate_mul(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
"""Compute x * sigmoid(gate) in a single fused kernel."""
# ... (grid setup and kernel launch)
@triton.jit
def _sigmoid_gate_mul_broadcast_kernel(
out_ptr,
gate_ptr,
x_ptr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
row = tl.program_id(0)
# Load gate for the current row, promote to float32
g = tl.load(gate_ptr + row).to(tl.float32)
g = tl.sigmoid(g)
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < hidden_dim
# Load x for the current row, promote to float32
x = tl.load(x_ptr + row * hidden_dim + offs, mask=mask).to(tl.float32)
out = x * g
# Store result in original dtype
tl.store(
out_ptr + row * hidden_dim + offs,
out.to(x_ptr.dtype.element_ty),
mask=mask,
)
def sigmoid_gate_mul_broadcast(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
"""Compute x * sigmoid(gate) where gate is (N, 1) and x is (N, D)."""
# ... (grid setup, num_warps calculation, and kernel launch)
_sigmoid_gate_mul_broadcast_kernel에서는 각 행(row)마다 gate 값을 로드하여 sigmoid를 계산하고, 이를 해당 행의 모든 hidden_dim 요소에 대해 브로드캐스팅하여 x와 곱합니다. 연산 과정에서 수치적 안정성을 위해 float32로 업캐스팅하는 점도 주목할 만합니다.
2. 모델 코드 수정: qwen2_moe.py
Qwen2MoeSparseMoeBlock 클래스의 _forward_shared_experts 메소드 내에서 AMD GPU(_is_hip 조건)일 경우, 기존의 F.sigmoid와 곱셈 연산을 새로운 sigmoid_gate_mul_broadcast 함수 호출로 대체합니다.
주요 코드 변경:
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -407,6 +407,13 @@
True,
shared_output,
)
+ elif _is_hip:
+ from sglang.jit_kernel.triton.sigmoid_gate_mul import \
+ sigmoid_gate_mul_broadcast
+
+ gate = self.shared_expert_gate(hidden_states)
+ shared_output = sigmoid_gate_mul_broadcast(shared_output, gate)
+
else:
shared_output = (
F.sigmoid(self.shared_expert_gate(hidden_states))
이 변경은 _use_aiter 플래그와 함께 사용될 때 (_is_hip 조건이 더 넓은 범위일 수 있음) AMD GPU에서 새로운 융합 커널을 사용하도록 경로를 설정합니다. CUDA 경로는 변경되지 않고 유지됩니다.
3. 단위 테스트 추가: test_sigmoid_gate_mul.py
새로운 커널의 정확성과 안정성을 보장하기 위해 40개의 단위 테스트가 추가되었습니다. 이 테스트들은 bf16, fp16, fp32 데이터 타입과 다양한 텐서 형태에 대해 element-wise 및 broadcast 변형 모두를 검증합니다. 또한, 입력 불변성, 출력 데이터 타입, 연속성(contiguity) 등을 확인하여 커널이 예상대로 동작하는지 보장합니다.
주요 코드 변경:
# test/registered/kernels/test_sigmoid_gate_mul.py
# ... (imports and setup)
# --- element-wise variant ---
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("shape", [(1, 4096), (4, 4096), ...])
def test_sigmoid_gate_mul_correctness(shape, dtype):
from sglang.jit_kernel.triton.sigmoid_gate_mul import sigmoid_gate_mul
# ... (test logic comparing with reference)
# --- broadcast variant ---
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("shape", [(1, 4096), (4, 4096), ...])
def test_sigmoid_gate_mul_broadcast_correctness(shape, dtype):
from sglang.jit_kernel.triton.sigmoid_gate_mul import sigmoid_gate_mul_broadcast
# ... (test logic comparing with reference)
# ... (tests for input immutability, output dtype, contiguity)
왜 이게 좋은가?
이 PR의 핵심적인 개선점은 다음과 같습니다.
1. 커널 실행 오버헤드 감소
AMD GPU(MI355X) 환경에서는 작은 연산의 경우 커널 실행 자체에 상당한 오버헤드가 발생합니다. PR 설명에 따르면, 기존에는 sigmoid_kernel (~3.8 us)과 elementwise_mul (~4.6 us) 두 개의 커널이 실행되었으며, 각 커널 론칭에 약 4 us의 오버헤드가 소요되었습니다. 이는 총 8.4 us의 연산 시간 외에 약 8 us의 추가 오버헤드를 발생시킵니다.
| Stage | Kernels | Time (us) | Kernel Names |
|---------|---------|-----------|--------------|
| Before | 2 | 8.4 | sigmoid_kernel_cuda + elementwise_kernel BinaryFunctor (mul) |
| After | 1 | 4.0 | _sigmoid_gate_mul_broadcast_kernel |
| **Savings** | | **4.4** | x60 A-type layers = **~264 us/iteration** |
하나의 Triton 커널로 융합함으로써, _sigmoid_gate_mul_broadcast_kernel은 약 4.0 us의 시간으로 연산을 완료합니다. 이는 커널 론칭 오버헤드를 절반으로 줄여, A-type 디코드 레이어당 약 4.4 us의 시간 절약을 가져옵니다. 모델 전체적으로는 약 264 us/iteration의 절감 효과를 기대할 수 있습니다.
2. 성능 향상 및 회귀 없음 확인
E2E 벤치마크 결과, Qwen3.5-397B 모델에서 총 처리량(throughput)이나 지연 시간(TTFT, ITL, TPOT)에서 유의미한 성능 저하(regression)가 관찰되지 않았습니다. 오히려 미미한 성능 향상이 있었으며, 이는 융합된 커널이 기존 연산보다 효율적임을 시사합니다.
| Metric | Baseline | After | Delta |
|--------|----------|-------|-------|
| Total throughput (tok/s) | 2986.62 | 2982.67 | -0.1% |
| Median TTFT (ms) | 391.24 | 390.09 | -0.3% |
| Median ITL (ms) | 10.31 | 10.33 | +0.2% |
| Median TPOT (ms) | 11.42 | 11.40 | -0.2% |
특히, 중간 레이턴시(ITL) 개선 기대치(약 2.6%)는 실제 벤치마크 결과에서 약간의 노이즈 범위 내에 있었지만, 성능 저하가 없다는 점이 중요합니다.
3. 수치적 안정성 및 정확도 보장
새로운 Triton 커널은 sigmoid 연산을 수행하기 전에 입력 값을 float32로 업캐스팅하여 수치적 안정성을 높였습니다. 또한, 광범위한 단위 테스트와 실제 모델(Qwen3.5-397B-A17B-MXFP4)에 대한 정확도 테스트(GSM8K Accuracy 0.909, Invalid rate 0.3%)를 통과하여, 성능 최적화가 모델의 정확성에 영향을 미치지 않음을 검증했습니다.
일반적인 교훈
- 작은 연산의 융합: GPU에서 작은 연산들이 개별적으로 실행될 때 발생하는 커널 론칭 오버헤드는 무시할 수 없습니다. 이러한 연산들을 Triton과 같은 커스텀 커널로 융합하는 것은 상당한 성능 향상을 가져올 수 있습니다.
- 하드웨어 특성 고려: AMD GPU의 경우 커널 론칭 오버헤드가 특히 두드러질 수 있으므로, Triton 커널을 활용한 최적화가 더욱 효과적일 수 있습니다.
- 철저한 테스트: 새로운 커널을 도입할 때는 다양한 데이터 타입, 형태, 그리고 실제 모델 통합 시나리오에 대한 철저한 단위 테스트와 정확도/성능 벤치마킹이 필수적입니다.
- 점진적 도입: CUDA 경로를 유지하면서 AMD GPU에만 새로운 최적화 경로를 적용하는 방식(
_is_hip조건)은 위험을 줄이고 점진적인 도입을 가능하게 합니다.
리뷰 과정 및 CI 상태
리뷰 과정에서 CI는 새로운 Triton 커널 자체에 대해서는 AMD 및 NVIDIA 환경에서 모두 통과했음을 보여주었습니다. 특히 test_sigmoid_gate_mul.py의 40개 테스트 케이스가 성공적으로 실행되었습니다.
하지만, qwen2_moe.py 파일의 실제 모델 코드 통합 부분(_is_hip 분기)은 PR CI 환경에서 직접적으로 실행되지 않는다는 점이 지적되었습니다. 이는 해당 경로가 특정 환경 변수(SGLANG_USE_AITER=1)와 모델 아키텍처(shared_expert_gate)에 의해 트리거되기 때문입니다. CI는 커널의 수학적 정확성은 검증했지만, 모델 레벨에서의 통합 및 동작은 검증하지 못한 상태였습니다.
이에 따라, 병합 전에 AMD 환경에서 해당 통합 경로를 실제로 실행해보거나, Qwen MoE 모델에 대한 정확도 테스트를 수동으로 수행하는 것이 권장되었습니다. 최종적으로, PR 설명에 포함된 Qwen3.5 모델의 GSM8K 정확도 결과(0.909)가 이러한 검증의 근거로 제시되었습니다.
결론
PR #27636은 AMD GPU 환경에서 Qwen2 MoE 모델의 성능을 최적화하기 위한 중요한 발걸음입니다. Triton 커널을 활용하여 두 개의 연산을 하나로 융합함으로써 커널 론칭 오버헤드를 줄이고, 실제 벤치마크에서 성능 향상 및 회귀 없음을 입증했습니다. 또한, 철저한 단위 테스트와 정확도 검증을 통해 최적화가 모델의 정확성에 영향을 미치지 않음을 확인했습니다. 이 PR은 GPU 컴퓨팅에서 커널 융합의 중요성과 하드웨어 특성을 고려한 최적화 전략의 효과를 잘 보여주는 사례입니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/sgl-project/sglang/blob/main/python/sglang/jit_kernel/triton/sigmoid_gate_mul.py
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_moe.py
- https://github.com/sgl-project/sglang/blob/main/test/registered/kernels/test_sigmoid_gate_mul.py
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] [SGLang] Blackwell(B200)에서 Diffusion Attention 성능을 7배 끌어올리는 Triton 커널 최적화 분석
- [sglang] LTX2 스플릿 로터리 커널 최적화: 헤드 배치 처리로 성능 2배 향상
- [sglang] SGLang 성능 최적화: PDL 도입과 안전한 CUDA 동기화로 DSV3.2/GLM-5 가속하기
- [sglang] AMD ROCm 환경에서의 DeepSeek-V4 성능 최적화: Aiter MHC 커널 통합 분석
- [sglang] AMD ROCm 환경에서의 성능 최적화: Triton을 활용한 Fused QK GemmaRMSNorm 구현
PR Analysis 의 다른글
- 이전글 [cpython] CPython unicodedata.normalize() 최적화: Py_UCS4 버퍼 직접 조작으로 성능 향상
- 현재글 : [sglang] AMD GPU 최적화: Triton 커널 퓨전을 통한 Qwen2 MoE 공유 전문가 게이팅 성능 향상
- 다음글 [sglang] [성능 최적화] Wan2.2 모델을 위한 최적의 torch.compile 모드 찾기: 왜 'default'가 더 빠를까?
댓글