본문으로 건너뛰기

[flashinfer] FlashInfer 오토튜너 최적화: 하이브리드 토큰 버킷 도입

PR 링크: flashinfer-ai/flashinfer#3115 상태: Merged | 변경: +None / -None

들어가며

고성능 LLM 추론 엔진인 FlashInfer에서 커널 오토튜닝(Autotuning)은 최적의 성능을 끌어내기 위한 핵심 단계입니다. 기존에는 토큰 수를 2의 거듭제곱(power-of-2) 단위로 버킷팅하여 튜닝을 수행했습니다. 하지만 이 방식은 큰 값으로 갈수록 버킷 간의 간격이 너무 넓어져(예: 1024에서 2048로 급격히 점프), 실제 워크로드와 튜닝된 커널 간의 미스매치가 발생하는 문제가 있었습니다. 본 PR은 이를 해결하기 위해 '하이브리드 토큰 버킷' 방식을 도입했습니다.

코드 분석

1. flashinfer/fused_moe/utils.py: 하이브리드 버킷 로직 구현

기존의 단순한 2의 거듭제곱 방식 대신, 4단계의 하이브리드 페이즈를 도입하여 버킷 간격을 세밀하게 조정했습니다.

# Before
def get_last_power_of_2_num_tokens_buckets(max_num_tokens, min_num_tokens=1):
    max_num_tokens = last_positive_power_of_2(max_num_tokens)
    # ... (단순 2의 거듭제곱 루프)

# After
def get_hybrid_num_tokens_buckets(max_num_tokens: int, min_num_tokens: int = 1):
    # Phase 1: [min .. 256] — power-of-2
    # Phase 2: (256 .. 2048] — linear step 256
    # Phase 3: (2048 .. 4096] — linear step 512
    # Phase 4: (4096 .. max] — power-of-2
    # ... (4단계 로직 구현)

2. flashinfer/fused_moe/core.py: 호출부 업데이트

MoERunner 및 관련 클래스에서 사용하는 DynamicTensorSpec의 버킷 생성 및 매핑 함수를 새로운 하이브리드 API로 교체했습니다.

# Before
get_last_power_of_2_num_tokens_buckets(8192),
lambda x: min(last_positive_power_of_2(x), 8192),

# After
get_hybrid_num_tokens_buckets(8192),
lambda x: map_to_hybrid_bucket(x, 8192),

3. flashinfer/fused_moe/core.py: 버그 수정

trtllm_fp8_per_tensor_scale_moe 호출 시 누락되었던 routing_replay_out 인자를 추가하여 데이터 흐름을 정상화했습니다.

왜 이게 좋은가

기존의 2의 거듭제곱 방식은 1024와 2048 사이의 워크로드 변화를 제대로 반영하지 못했습니다. MoE 모델의 경우 avg_tokens_per_expert가 이 간격 사이에서 자주 변동하는데, 기존 방식은 튜닝된 커널이 실제 워크로드와 너무 동떨어진 설정을 사용하게 만들었습니다.

이번 개선을 통해 얻은 교훈은 다음과 같습니다:

  1. 워크로드 특성 반영: 범용적인 2의 거듭제곱보다, 실제 하드웨어 커널이 민감하게 반응하는 구간(256~4096)에 대해 선형적인 간격을 두는 것이 튜닝 정확도에 훨씬 유리합니다.
  2. 유연한 매핑: map_to_hybrid_bucket을 통해 입력 토큰 수를 최적의 버킷으로 매핑함으로써, 오토튜너가 더 적절한 커널을 선택하도록 유도했습니다.

이 변경은 단순한 리팩토링을 넘어, 다양한 토큰 수 환경에서 커널 실행 효율성을 극대화하는 실질적인 성능 최적화 사례입니다.

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글