[sglang] SGLang MoE 라우팅 최적화: AMD GPU에서 aiter.biased_grouped_topk 활용
PR 링크: sgl-project/sglang#23611 상태: Merged | 변경: +None / -None
들어가며
대규모 언어 모델(LLM)의 효율적인 서빙은 항상 중요한 과제입니다. 특히 Mixture-of-Experts(MoE) 모델은 모델의 파라미터 수를 크게 늘리면서도 추론 시 활성화되는 파라미터는 제한하여 효율성을 높이는 구조를 가집니다. MoE 모델의 핵심은 입력 토큰을 어떤 전문가(Expert)에게 보낼지 결정하는 '라우팅(Routing)' 메커니즘입니다. 이 라우팅 과정은 모델의 전체 성능에 큰 영향을 미치며, 특히 GPU에서 이 과정이 얼마나 효율적으로 실행되느냐가 중요합니다.
이번 PR은 sgl-project/sglang 레포지토리에서 AMD GPU 환경의 MoE 라우팅 최적화를 다룹니다. 특히 MiniMax-M2.5와 같이 sigmoid 스코어링과 correction_bias를 사용하는 모델에서 기존 sgl_kernel.topk_sigmoid 대신 aiter.biased_grouped_topk 커널을 사용하여 MoE 라우팅 오버헤드를 약 35% 줄이는 것을 목표로 합니다. 이는 전반적인 출력 처리량(output throughput)을 향상시키는 중요한 개선입니다.
코드 분석: AMD GPU MoE 라우팅 커널 교체
핵심 변경사항은 python/sglang/srt/layers/moe/topk.py 파일에 있습니다. 이 파일은 MoE 라우팅 시 전문가를 선택하는 topk 연산을 담당합니다. 특히 sigmoid 스코어링 함수를 사용할 때의 로직이 변경되었습니다.
python/sglang/srt/layers/moe/topk.py
이 파일의 fused_topk 함수는 MoE 라우팅의 핵심 로직을 담고 있습니다. scoring_func가 sigmoid일 때, 기존에는 항상 topk_sigmoid 함수를 호출했습니다. 하지만 이번 변경으로 특정 조건(_use_aiter가 True이고 correction_bias가 None이 아닐 때)에서 aiter_biased_grouped_topk 함수를 사용하도록 조건부 로직이 추가되었습니다.
Before:
def fused_topk(
# ... (생략)
):
# ... (생략)
elif scoring_func == "sigmoid":
topk_sigmoid(
topk_weights,
topk_ids,
gating_output,
renormalize,
correction_bias,
)
else:
raise ValueError(f"Invalid scoring function: {scoring_func}")
After:
def fused_topk(
# ... (생략)
):
# ... (생략)
elif scoring_func == "sigmoid":
if _use_aiter and correction_bias is not None:
aiter_biased_grouped_topk(
gating_output,
correction_bias.to(dtype=gating_output.dtype),
topk_weights,
topk_ids,
num_expert_group=1,
topk_group=1,
need_renorm=renormalize,
)
else:
topk_sigmoid(
topk_weights,
topk_ids,
gating_output,
renormalize,
correction_bias,
)
else:
raise ValueError(f"Invalid scoring function: {scoring_func}")
이 변경의 핵심은 _use_aiter 플래그와 correction_bias의 존재 여부에 따라 다른 커널을 선택한다는 점입니다. _use_aiter는 AMD GPU 환경에서 aiter 라이브러리의 최적화된 커널을 사용할 수 있는지 여부를 나타내는 내부 플래그로 추정됩니다. correction_bias는 MiniMax-M2.5와 같은 특정 모델에서 sigmoid 스코어링과 함께 사용되는 추가적인 바이어스 값입니다.
새롭게 도입된 aiter_biased_grouped_topk는 aiter 라이브러리에서 제공하는 ASM(Assembly) 커널로, biased_grouped_topk라는 이름에서 알 수 있듯이 바이어스가 적용된 그룹화된 topk 연산에 특화되어 있습니다. 이 커널은 기존 topk_sigmoid 커널보다 AMD GPU에서 훨씬 더 효율적으로 동작하도록 설계되었습니다.
왜 이게 좋은 최적화인가?
이 최적화는 다음과 같은 이유로 매우 중요하고 좋은 개선입니다.
-
성능 향상: PR 설명에 따르면,
aiter.biased_grouped_topk커널은 기존sgl_kernel.topk_sigmoid커널 대비 호출당 약 35% 빠른 속도(6 us/call vs 9.3 us/call)를 보여줍니다. 이는 MoE 라우팅 오버헤드를 크게 줄여줍니다. -
실제 처리량 개선: 벤치마크 결과는 이러한 마이크로 최적화가 실제 시스템 성능으로 이어진다는 것을 보여줍니다. MiniMax-M2.5 FP8 모델(MI355X GPU, TP=4, ISL=8192, OSL=1024)에서
conc=64일 때 출력 처리량이 2.0% 증가했고,conc=32일 때 2.4% 증가했습니다. 모든 동시성 레벨(conc=4..128)에서 성능 저하가 없었다는 점도 중요합니다. -
특정 하드웨어 및 모델에 대한 맞춤 최적화: 이 최적화는 AMD GPU와
sigmoid스코어링 및correction_bias를 사용하는 모델(예: MiniMax-M2.5)에 특화되어 있습니다. 특정 하드웨어 아키텍처의 특성을 최대한 활용하는 ASM 커널을 사용함으로써, 일반적인 CUDA/HIP 커널로는 달성하기 어려운 수준의 성능을 끌어낼 수 있습니다. 이는 다양한 하드웨어 환경에서 최적의 성능을 제공하기 위한 SGLang의 노력을 보여줍니다. -
정확도 유지: 성능 향상과 더불어, GSM8K 벤치마크에서 정확도가 93.3%에서 93.4%로 미미하게 상승하거나 최소한 유지되었다는 점은 이 최적화가 모델의 추론 결과에 부정적인 영향을 미치지 않음을 의미합니다. 성능과 정확도라는 두 마리 토끼를 모두 잡은 좋은 사례입니다.
-
모듈성 및 조건부 적용:
if _use_aiter and correction_bias is not None:조건문을 통해 이 최적화가 필요한 특정 환경(AMD GPU + 특정 모델)에서만 적용되도록 설계되었습니다. 이는 다른 환경에서는 기존의 안정적인 코드를 유지하면서도, 특정 환경에서는 최적의 성능을 제공할 수 있도록 하는 유연한 접근 방식입니다.
결론
이번 PR은 SGLang에서 AMD GPU 환경의 MoE 라우팅 성능을 크게 향상시키는 중요한 최적화입니다. aiter.biased_grouped_topk와 같은 저수준(low-level) ASM 커널을 활용하여 sigmoid 스코어링 연산의 효율을 높임으로써, MiniMax-M2.5와 같은 MoE 모델의 추론 처리량을 실질적으로 개선했습니다. 이는 특정 하드웨어 아키텍처의 강점을 최대한 활용하고, 모델의 정확도를 유지하면서도 성능을 극대화하는 모범적인 최적화 사례라고 할 수 있습니다.
이러한 최적화는 LLM 서빙 시스템의 전반적인 효율성을 높이고, 더 많은 사용자가 더 빠르게 모델을 사용할 수 있도록 기여합니다. 앞으로도 SGLang이 다양한 하드웨어 환경에서 최적의 성능을 제공하기 위한 노력을 계속할 것으로 기대됩니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.topk.html
- https://pytorch.org/docs/stable/generated/torch.sigmoid.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] AMD GPU에서 FP8 KV 캐시 쓰기 최적화: Triton 커널 융합으로 성능 향상
- 현재글 : [sglang] SGLang MoE 라우팅 최적화: AMD GPU에서 aiter.biased_grouped_topk 활용
- 다음글 [sglang] AMD ROCm 환경에서의 성능 최적화: Triton을 활용한 Fused QK GemmaRMSNorm 구현
댓글