본문으로 건너뛰기

[sglang] SGLang의 FA3 디코드 최적화: get_scheduler_metadata 도입

PR 링크: sgl-project/sglang#21103 상태: Merged | 변경: +128 / -0

들어가며

LLM 추론 성능을 극대화하기 위해 FlashAttention-3(FA3)와 같은 최신 커널을 사용하는 것은 필수적입니다. 하지만 복잡한 커널 실행 과정에서 반복되는 연산은 전체 추론 지연 시간(Latency)의 병목이 될 수 있습니다. 특히 prepare_varlen_num_blocks와 같이 매 레이어마다 수행되는 커널은 모델의 레이어 수가 깊어질수록 누적되는 오버헤드가 상당합니다. 이번 SGLang의 PR은 FA3의 타일 스케줄링 메타데이터를 사전에 계산하여, 레이어별 반복 호출을 방지함으로써 디코드 성능을 개선합니다.

코드 분석

이번 변경은 C++ 수준에서 정의된 심볼을 PyTorch 연산으로 노출하고, 이를 Python 레이어에서 활용할 수 있도록 래핑하는 과정을 포함합니다.

1. C++ 커널 등록 (sgl-kernel/csrc/flash_extension.cc)

기존에 flash_ops.so에 존재했지만 PyTorch에서 직접 접근할 수 없었던 mha_fwd_get_scheduler_metadata 함수를 TORCH_LIBRARY_FRAGMENT를 통해 등록했습니다.

m.def("get_scheduler_metadata(...) -> Tensor");
m.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata));

2. 헤더 선언 (sgl-kernel/include/sgl_flash_kernel_ops.h)

C++ 구현체와 PyTorch 바인딩 사이의 인터페이스를 정의합니다. at::Tensorstd::optional을 사용하여 다양한 입력 파라미터를 유연하게 처리하도록 설계되었습니다.

at::Tensor mha_fwd_get_scheduler_metadata(
    int64_t batch_size, ..., at::Tensor seqused_k, 
    std::optional<at::Tensor> cu_seqlens_q_, ...);

3. Python 래퍼 구현 (sgl-kernel/python/sgl_kernel/flash_attn.py)

사용자가 쉽게 호출할 수 있도록 Python 함수를 제공합니다. 핵심은 이 함수를 배치당 한 번만 호출하여 얻은 메타데이터를 이후 레이어의 연산에 재사용하는 것입니다.

def get_scheduler_metadata(...):
    return torch.ops.sgl_kernel.get_scheduler_metadata(...)

왜 이게 좋은가

이 최적화의 핵심 교훈은 **'반복되는 불변 연산의 제거(Hoisting)'**입니다.

  1. 레이어 오버헤드 감소: 디코드 단계에서 prepare_varlen_num_blocks는 입력 시퀀스 길이에 의존하지만, 레이어 간에는 동일한 메타데이터를 공유할 수 있습니다. 이를 사전에 계산함으로써 각 레이어마다 커널을 실행하는 GPU 커널 런치 오버헤드를 제거했습니다.
  2. 효율적인 리소스 활용: GPU 커널 런치(Launch)는 CPU에서 GPU로 명령을 전달하는 과정에서 비용이 발생합니다. 레이어 수가 많은 모델(예: Llama-3 70B 등)에서 이 최적화는 누적된 오버헤드를 획기적으로 줄여줍니다.

일반적으로 고성능 커널을 작성할 때, 데이터 의존성이 없는 메타데이터 계산은 메인 연산 루프 밖으로 분리하는 것이 추론 성능 최적화의 정석입니다. 이번 PR은 SGLang이 대규모 모델 추론에서 더 낮은 Latency를 달성하기 위한 중요한 발판을 마련했습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글