본문으로 건너뛰기

[sglang] SGLang에서 FA4(FlashAttention 4)와 Speculative Decoding의 완벽한 결합

PR 링크: sgl-project/sglang#21080 상태: Merged | 변경: +None / -None

들어가며

최근 LLM 서빙 환경에서는 메모리 효율성과 처리량(throughput)을 극대화하기 위해 FP4 기반의 FA4(FlashAttention 4) 도입이 활발합니다. 하지만 기존 SGLang의 Speculative Decoding 파이프라인은 FA4와 호환되지 않아, 지연 시간(latency)에 민감한 시나리오에서 성능 최적화에 제약이 있었습니다. 이번 PR은 FA4를 Speculative Decoding 흐름에 통합하여, 저정밀도 어텐션과 추론 가속 기법을 동시에 활용할 수 있도록 개선했습니다.

코드 분석

1. python/sglang/srt/speculative/draft_utils.py: 백엔드 통합

기존에는 FA3와 FA4가 분리되어 관리되었으나, 코드 중복을 줄이고 유지보수성을 높이기 위해 범용적인 _create_fa_decode_backend_create_fa_prefill_backend를 도입했습니다.

# Before/After: 범용 백엔드 생성 로직 도입
def _create_fa_decode_backend(self, fa_impl_ver: int = 3):
    return FlashAttentionMultiStepBackend(
        self.draft_model_runner,
        self.topk,
        self.speculative_num_steps,
        fa_impl_ver=fa_impl_ver,
    )

# 래퍼를 통해 일관성 유지
def _create_fa4_decode_backend(self):
    return self._create_fa_decode_backend(fa_impl_ver=4)

2. python/sglang/srt/layers/attention/flashattention_backend.py: 파라미터 전달

FlashAttentionMultiStepBackend 클래스에 fa_impl_ver를 추가하여 추론 시점에 적절한 어텐션 버전을 선택할 수 있도록 했습니다.

# After: 생성자 파라미터 확장
def __init__(self, model_runner, topk, speculative_num_steps, fa_impl_ver=3):
    self.fa_impl_ver = fa_impl_ver
    # ... 이후 spec 단계별 초기화 시 fa_impl_ver 전달

왜 이게 좋은가

이번 최적화의 핵심은 **'유연한 백엔드 선택'**입니다. FA4를 Speculative Decoding에 통합함으로써, 사용자는 EAGLE3와 같은 고성능 추론 알고리즘을 사용하면서도 FA4의 메모리 절감 효과를 누릴 수 있습니다.

  • 성능 향상: 벤치마크 결과, 120B 모델 기준 FA4를 적용한 Speculative Decoding은 기존 대비 높은 처리량을 기록했습니다.
  • 유지보수성: 리뷰어 Qiaolin-Yu의 제안대로 fa_impl_ver 파라미터를 활용한 공통 함수를 작성하여, 향후 새로운 FA 버전이 나오더라도 백엔드 맵을 쉽게 확장할 수 있는 구조가 되었습니다.

이러한 구조적 개선은 복잡한 추론 파이프라인에서 하드웨어 가속 기술(FA4)과 알고리즘 가속 기술(Speculative Decoding)을 결합할 때 발생하는 '백엔드 불일치' 문제를 해결하는 표준적인 패턴을 제시합니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글