본문으로 건너뛰기

[flashinfer] NVIDIA Blackwell SM120을 위한 MoE Short-Decode 최적화 분석

PR 링크: flashinfer-ai/flashinfer#3193 상태: Merged | 변경: +None / -None

들어가며

LLM 추론 성능의 핵심 중 하나는 Mixture of Experts(MoE) 모델의 효율적인 처리입니다. 특히 토큰을 하나씩 생성하는 Decode 단계에서는 배치 사이즈가 작을 때(Short-decode) 발생하는 오버헤드를 줄이는 것이 관건입니다.

이번 flashinfer PR은 NVIDIA의 차세대 아키텍처인 Blackwell(SM120) 환경에서 B12x MoE 커널의 성능을 최적화하는 내용을 담고 있습니다. 핵심은 단일 토큰(Single-token) 상황에서 불필요한 전처리 과정을 생략하고, 하드웨어 자원을 더 정밀하게 제어하는 'Micro-kernel'의 고도화입니다.

코드 분석: 무엇이 바뀌었는가?

1. ReLU2 단일 토큰 디스패치 지름길 (Shortcut)

기존에는 MoE 연산을 위해 Triton 기반의 Compaction 커널을 먼저 실행하여 활성화된 전문가(Expert)들을 정리하는 과정이 필요했습니다. 하지만 단일 토큰의 경우 top-k 전문가가 이미 명확하므로 이 과정을 생략할 수 있습니다.

Before:

# 모든 경우에 Triton compaction 커널을 호출하여 전문가 ID를 정리함
from .triton_compact import compact_topk_ids as _triton_compact_topk_ids
_triton_compact_topk_ids(...)
launch_ids = compact_ids

After:

# ReLU2 활성화 함수를 사용하는 단일 토큰의 경우, 별도의 정리 없이 바로 flat_ids 사용
if num_tokens == 1 and activation == "relu2":
    launch_ids = flat_ids
elif num_tokens == 1:
    # Gated SiLU 등은 여전히 compact_ids를 사용하되 호스트에서 직접 매핑 생성
    compact_ids.copy_(torch.arange(flat_ids.numel(), ...))
    launch_ids = compact_ids
else:
    # 다중 토큰일 때만 Triton 커널 실행
    _triton_compact_topk_ids(...)
    launch_ids = compact_ids

이 변경을 통해 m=1인 상황에서 GPU 커널 런칭 오버헤드를 줄이고 지연 시간(Latency)을 단축했습니다.

2. 입력 및 전문가 스케일 공유 최적화

단일 토큰 디코딩 시 모든 전문가가 동일한 입력값을 보게 됩니다. 이를 활용해 양자화된 입력을 한 번만 계산하고 공유하는 share_input_across_experts 로직이 강화되었습니다.

moe_dispatch.py 변경점:

# 입력 스케일과 전문가별 스케일이 모두 단일 값(Shared)인지 확인
input_gs_is_shared = input_gs.numel() == 1
down_input_scale_is_shared = down_input_scale.numel() == 1

# ... 중략 ...

share_expert_scales = (
    activation == "relu2" and input_gs_is_shared and down_input_scale_is_shared
)

이 정보는 마이크로 커널로 전달되어, 커널 내부에서 불필요한 중복 계산이나 메모리 쓰기를 방지합니다.

3. 마이크로 커널 내부의 조건부 컴파일 최적화

moe_micro_kernel.py에서는 cutlass.const_expr을 사용하여 컴파일 타임에 경로를 최적화합니다. 단일 토큰 여부에 따라 런타임 분기를 제거하여 실행 효율을 높였습니다.

moe_micro_kernel.py (StorageRelu2 클래스):

# Before: 런타임 변수 all_rows_unique로 체크
if all_rows_unique == Int32(0):
    # ... 전문가 카운트 초기화 로직 ...

# After: 컴파일 타임 상수를 활용한 최적화
if cutlass.const_expr(not self.single_token):
    i = flat_tid
    while i < num_experts:
        row_counts[i] = Int32(0)

왜 이게 좋은가?

  1. 지연 시간(Latency) 최소화: 단일 토큰 생성 시 0.030ms라는 매우 낮은 Median 지연 시간을 달성했습니다. 이는 Triton 커널 호출을 생략하고 마이크로 커널 내에서 직접 전문가 ID를 처리함으로써 얻은 결과입니다.
  2. 하드웨어 활용도 극대화: Blackwell의 SM120 아키텍처 특성에 맞춰 MAC(Max Active Clusters) 값을 튜닝했습니다. 특히 routed_rows가 적을 때(40 미만) Static MAC을 64로 제한하여 자원 경합을 방지했습니다.
  3. 메모리 대역폭 절약: share_input_across_experts 옵션을 통해, 여러 전문가가 동일한 입력 데이터를 처리할 때 발생하는 중복 로드 및 양자화 연산을 제거했습니다.

마치며

이번 PR은 MoE 모델의 실시간 서비스 성능을 높이기 위해 "단일 토큰"이라는 특수한 상황을 얼마나 깊게 최적화할 수 있는지 보여주는 좋은 사례입니다. 특히 Blackwell과 같은 최신 GPU 아키텍처에서 마이크로 커널을 통해 하드웨어 제어권을 세밀하게 가져가는 전략은 시니어 엔지니어들이 눈여겨봐야 할 대목입니다.

이러한 최적화 기법은 대규모 언어 모델의 서빙 비용을 절감하고 사용자 경험을 개선하는 데 직접적인 기여를 합니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글