본문으로 건너뛰기

[sglang] SGLang DeepSeekV3 Router GEMM 최적화: FlashInfer 커널 도입 및 벤치마킹

PR 링크: sgl-project/sglang#17707 상태: Merged | 변경: +None / -None

들어가며

대규모 언어 모델(LLM)의 추론 성능은 모델의 상업적 활용에 있어 매우 중요합니다. 특히 Mixture-of-Experts (MoE) 아키텍처를 사용하는 DeepSeekV3와 같은 모델에서는 라우터(router)의 GEMM(General Matrix Multiply) 연산이 전체 성능에 큰 영향을 미칩니다. 이 PR은 sgl-project/sglang 레포지토리에서 DeepSeekV3 라우터 GEMM의 성능을 최적화하기 위해 flashinfer 라이브러리의 최적화된 커널을 도입하고, 기존 sglang 커널과의 성능 비교 벤치마킹을 수행하는 내용을 담고 있습니다. 목표는 flashinfer 커널이 특정 조건에서 더 나은 성능을 제공하는지 확인하고, 이를 통해 전체 추론 속도를 향상시키는 것입니다.

코드 분석

이 PR은 주로 두 가지 핵심 변경사항을 포함합니다. 첫째, flashinfer 라이브러리의 dsv3_router_gemm 커널을 벤치마킹하기 위한 스크립트를 추가합니다. 둘째, sglang의 DeepSeekV2 모델 구현에서 특정 조건(Blackwell GPU, m=1~16, n=256, k=7168, tp_size=1~8)에서 flashinfer 커널을 사용하도록 로직을 변경합니다.

benchmark/kernels/deepseek/benchmark_deepgemm_dsv3_router_gemm_blackwell.py 파일 추가

이 파일은 sglang의 기존 dsv3_router_gemm 커널과 flashinfermm_M1_16_K7168_N256 커널의 정확성 및 성능을 비교하기 위한 벤치마킹 스크립트입니다. 주요 변경사항은 다음과 같습니다.

Before:

# /dev/null (새로 추가된 파일)

After:

import argparse
import os
from typing import List

import torch
import triton
from flashinfer.gemm import mm_M1_16_K7168_N256
from sgl_kernel import dsv3_router_gemm

N = 256
K = 7168


def create_benchmark_configs(tp_sizes: List[int]):
    configs = []
    for tp_size in tp_sizes:
        for m in range(1, 17):
            configs.append((m, N, K, tp_size))
    return configs


def dsv3_router_gemm_flashinfer(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
):
    """Flashinfer implementation of dsv3 router gemm"""
    output = torch.empty(
        hidden_states.shape[0],
        router_weights.shape[0],
        device="cuda",
        dtype=torch.float32,
    )
    mm_M1_16_K7168_N256(
        hidden_states, router_weights.t(), output, launch_with_pdl=args.use_pdl
    )
    return output


def dsv3_router_gemm_sgl(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
):
    """SGLang implementation of dsv3 router gemm"""
    output = dsv3_router_gemm(
        hidden_states,
        router_weights,
        out_dtype=torch.float32,
    )
    return output

# ... (정확성 및 벤치마킹 로직 생략)

if __name__ == "__main__":
    # ... (인자 파싱 및 초기화 로직 생략)

    if args.use_pdl:
        os.environ["TRTLLM_ENABLE_PDL"] = "1"

    # ... (정확성 테스트 및 벤치마킹 실행 로직 생략)

이 스크립트는 flashinfer.gemm.mm_M1_16_K7168_N256sgl_kernel.dsv3_router_gemm 두 함수를 직접 호출하여 벤치마킹합니다. 특히 flashinfer 커널의 경우 launch_with_pdl 인자를 통해 PDL(Pre-computed Dynamic Loading) 사용 여부를 제어할 수 있도록 했습니다. 초기 PR에서는 torch.randn을 사용하여 출력 텐서를 초기화했으나, 리뷰어 nv-yunzheq의 피드백에 따라 torch.empty로 변경하여 불필요한 오버헤드를 제거했습니다. 또한, contiguous() 호출도 제거하여 최적화된 벤치마킹 환경을 구축했습니다.

python/sglang/srt/models/deepseek_v2.py 파일 변경

이 파일은 DeepSeekV2 모델의 forward 메서드에서 라우터 GEMM 연산을 처리하는 부분을 수정합니다. 특정 조건에서 flashinfer 커널을 사용하도록 분기 로직을 추가했습니다.

Before:

# ...
if _is_cuda:
    from sgl_kernel import dsv3_fused_a_gemm, dsv3_router_gemm
# ...
def forward(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    router_weights: torch.Tensor,
    top_k: int,
    renormalize: bool,
):
    # ...
            if (
                _device_sm >= 90
                and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
            ):
                # router gemm output float32
                logits = dsv3_router_gemm(
                    hidden_states, self.weight, out_dtype=torch.float32
                )
            elif _use_aiter:
                logits = aiter_dsv3_router_gemm(hidden_states, self.weight)
            else:
                logits = F.linear(hidden_states, self.weight, None)
    # ...

After:

# ...
if _is_cuda:
    from flashinfer.gemm import mm_M1_16_K7168_N256 as _raw_dsv3_router_gemm
    from sgl_kernel import dsv3_fused_a_gemm, dsv3_router_gemm

    from sglang.srt.utils.custom_op import register_custom_op
# ...
def forward(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    router_weights: torch.Tensor,
    top_k: int,
    renormalize: bool,
):
    # ...
            if (
                _device_sm >= 90
                and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
            ):
                if _device_sm >= 100 and self.weight.shape[0] == 256:
                    # router gemm output float32
                    logits = torch.empty(
                        hidden_states.shape[0],
                        self.weight.shape[0],
                        device=hidden_states.device,
                        dtype=torch.float32,
                    )
                    flashinfer_dsv3_router_gemm(logits, hidden_states, self.weight)
                else:
                    logits = dsv3_router_gemm(
                        hidden_states, self.weight, out_dtype=torch.float32
                    )
            elif _use_aiter:
                logits = aiter_dsv3_router_gemm(hidden_states, self.weight)
            else:
                logits = F.linear(hidden_states, self.weight, None)
    # ...

@register_custom_op(
    op_name="flashinfer_dsv3_router_gemm",
    mutates_args=[],
    fake_impl=lambda logits, hidden_states, weight: None,
)
def flashinfer_dsv3_router_gemm(
    logits: torch.Tensor,
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
) -> None:
    _raw_dsv3_router_gemm(
        hidden_states,
        weight.t(),
        logits,
        launch_with_pdl=True,
    )

EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]

이 변경은 _device_sm >= 100 (Blackwell 이상 GPU)이고 self.weight.shape[0] == 256인 경우 flashinfer_dsv3_router_gemm 함수를 호출하도록 합니다. 이 함수는 flashinfer.gemm.mm_M1_16_K7168_N256 커널을 래핑하며, launch_with_pdl=True로 설정하여 PDL을 기본적으로 활성화합니다. 이는 리뷰 과정에서 PDL 활성화가 성능에 긍정적인 영향을 미친다는 논의를 반영한 것입니다. 또한, self.weight.shape[0] == 384 조건은 Kimi K2 모델을 위해 유지되었습니다.

왜 이게 좋은가

이 PR의 핵심은 DeepSeekV3 라우터 GEMM 연산의 잠재적 성능 향상을 탐색하고, flashinfer 라이브러리의 최적화된 커널을 통합하는 데 있습니다.

성능 수치 분석

초기 벤치마킹 결과는 flashinfer 커널이 특정 m 값(예: m=6, m=14)과 tp_size에서 기존 sglang 커널보다 약간 더 나은 성능을 보였으나, 전반적으로는 큰 차이가 없었습니다. 특히 launch_with_pdl=False일 때는 sglang 커널이 더 나은 경우가 많았습니다. 그러나 nv-yunzheq의 피드백에 따라 launch_with_pdl=True로 설정하고 TRTLLM_ENABLE_PDL 환경 변수를 사용하여 PDL을 활성화한 후 재벤치마킹한 결과, 두 커널의 성능은 거의 동등한 수준으로 나타났습니다.

PDL 활성화 시 벤치마크 결과 (일부 발췌):

m n k tp_size SGLang (us) Flashinfer (us)
1 256 7168 1 13.344000 13.344000
6 256 7168 1 15.360000 15.360000
14 256 7168 1 19.455999 19.487999

이 결과는 두 커널이 동일한 TRTLLM 기반에서 파생되었기 때문에 예상된 결과라는 리뷰어 leejnau의 설명과 일치합니다. 즉, flashinfer 커널이 현재 sglang 커널보다 압도적으로 우수하지는 않지만, 성능 패리티를 유지하면서 외부 최적화 라이브러리를 통합하는 좋은 사례를 보여줍니다.

일반적 교훈

  1. 외부 라이브러리 활용의 중요성: flashinfer와 같은 전문적인 고성능 컴퓨팅 라이브러리를 통합하는 것은 자체 커널 개발 및 유지보수 부담을 줄이면서 최신 하드웨어(Blackwell GPU)에 최적화된 성능을 활용할 수 있는 전략입니다.
  2. 정확한 벤치마킹 환경 구축: torch.randn 대신 torch.empty를 사용하고 contiguous() 호출을 제거하는 등, 벤치마킹 시 불필요한 오버헤드를 최소화하는 것이 중요합니다. 이는 실제 성능을 정확하게 측정하고 의미 있는 비교를 가능하게 합니다.
  3. PDL(Pre-computed Dynamic Loading)의 역할: PDL은 GPU 커널 실행 시 동적 로딩 오버헤드를 줄여 성능을 향상시키는 기술입니다. DeepSeekV3 라우터 GEMM과 같이 반복적으로 호출되는 연산에서 PDL을 활성화하는 것은 중요한 최적화 포인트입니다. 리뷰 과정에서 PDL 활성화가 기본 동작이 되어야 한다는 논의는 이러한 중요성을 강조합니다.
  4. 조건부 최적화: 특정 GPU 아키텍처(_device_sm >= 100) 및 모델 구성(self.weight.shape[0] == 256)에 따라 최적화된 커널을 선택적으로 적용하는 것은 유연하고 효율적인 성능 관리를 가능하게 합니다. 이는 모든 시나리오에 하나의 커널을 강제하는 대신, 가장 적합한 구현을 동적으로 선택함으로써 전반적인 시스템 성능을 극대화합니다.

이 PR은 당장 드라마틱한 성능 향상을 가져오지는 않지만, sglang이 외부 고성능 라이브러리를 통합하고 벤치마킹을 통해 성능을 검증하는 모범적인 개발 프로세스를 보여줍니다. 이는 향후 flashinfer 커널이 더욱 최적화되거나 다른 고성능 커널이 등장했을 때, sglang이 이를 신속하게 도입하고 활용할 수 있는 기반을 마련합니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글