[SGLang] Prefill Delayer: 전략적 프리필 지연으로 디코드 처리량 극대화
들어가며
LLM 서빙에서 프리필(prefill)과 디코드(decode)는 근본적으로 다른 연산 특성을 가진다. 프리필은 긴 입력을 한 번에 처리하는 compute-intensive 작업이고, 디코드는 토큰 하나씩 생성하는 memory-bandwidth-intensive 작업이다. 프리필 요청이 디코드 배치에 끼어들면 디코드 처리량이 급격히 하락한다. SGLang의 PrefillDelayer는 프리필 요청을 의도적으로 몇 패스 지연시켜, 디코드 배치가 최대 배치 크기에 도달할 때까지 기다리는 전략을 구현한다. 소스는 python/sglang/srt/managers/prefill_delayer.py에 있다.
구조도
DP Worker 0 DP Worker 1
┌──────────┐ ┌──────────┐
│prefillable│ │ decode │
│ = True │ │ only │
└─────┬─────┘ └─────┬─────┘
│ │
all_gather (4 ints per worker)
│
┌────────▼────────┐
│ prefillable_status │
│ "all" / "mixed" / "none"
└────────┬────────┘
│
┌──────────────┼──────────────┐
▼ ▼ ▼
"all" "mixed" "none"
┌─────────┐ ┌──────────┐ ┌─────────┐
│ 디코드BS │ │delay_count│ │ 무조건 │
│ < max? │ │< max-1? │ │ allow │
│ → delay │ │ → delay │ └─────────┘
│ else │ │ else │
│ allow │ │ allow │
└─────────┘ │(timeout) │
└──────────┘
핵심 코드 분석
1. 초기화: 지연 패스 수와 워터마크
PrefillDelayer는 두 가지 핵심 파라미터로 초기화된다: 최대 지연 패스 수(max_delay_passes)와 토큰 사용률 저수위 워터마크(token_usage_low_watermark).
class PrefillDelayer:
def __init__(self, dp_size, attn_tp_size, cpu_group, server_args,
max_delay_passes, token_usage_low_watermark, ...):
self._max_delay_passes = max_delay_passes
self._token_usage_low_watermark = token_usage_low_watermark
self.dp_size = dp_size
self.enable_dp_attention = server_args.enable_dp_attention
dp_size_dim = dp_size if self.enable_dp_attention else 1
self._global_info_buffer = torch.empty(
(dp_size_dim, attn_tp_size, 4),
dtype=torch.int64, device=device,
)
disaggregation_mode가 null이고 overlap_schedule이 활성화되어야만 사용 가능하다. 이는 PrefillDelayer가 오버랩 스케줄링 환경에서만 의미가 있기 때문이다.
2. 상태 머신: _State
지연 상태는 불변(frozen) 데이터클래스 _State로 관리된다. delayed_count가 지연 횟수를 추적하고, start_time이 지연 시작 시각을 기록한다.
@dataclass(frozen=True)
class _State:
delayed_count: int = 0
start_time: float = field(default_factory=time.perf_counter)
def bump_delayed_count(self) -> "_State":
return dataclasses.replace(self, delayed_count=self.delayed_count + 1)
frozen=True이므로 새 상태는 항상 dataclasses.replace로 생성한다. 이 불변 패턴은 상태 전이를 명시적으로 만들어 디버깅을 용이하게 한다.
3. 글로벌 정보 수집: _gather_info
각 워커는 4개 정수(prefillable, token_watermark_force_allow, running_batch, max_prefill_bs)를 all_gather로 교환한다.
def _gather_info(self, local_prefillable, local_token_watermark_force_allow, **kwargs):
local_info = torch.tensor(
[
int(local_prefillable),
int(local_token_watermark_force_allow),
kwargs.get("running_batch", 0),
kwargs.get("max_prefill_bs", 0),
],
device="cpu", dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
self._global_info_buffer.flatten(),
local_info, group=self._cpu_group,
)
tp0_info = self._global_info_buffer[:, 0, :]
return tp0_info
CPU 그룹에서 all_gather를 수행하므로 GPU 연산에 간섭하지 않는다.
4. 핵심 의사결정: _negotiate_should_allow_prefill_pure
이 함수가 PrefillDelayer의 두뇌다. 글로벌 prefillable 상태에 따라 세 가지 경로로 분기한다.
def _negotiate_should_allow_prefill_pure(self, prev_state, local_prefillable,
token_usage, **kwargs):
# 로컬 워터마크 판단
local_token_watermark_force_allow = (
local_prefillable
and ((x := self._token_usage_low_watermark) is not None)
and (token_usage < x)
)
# 글로벌 상태 수집
tp0_info = self._gather_info(...)
global_prefillable = tp0_info[:, 0]
if global_prefillable.min().item() > 0:
prefillable_status = "all"
elif global_prefillable.max().item() == 0:
prefillable_status = "none"
else:
prefillable_status = "mixed"
"all" (모든 워커가 프리필 가능): 디코드 배치의 여유 슬롯이 max_prefill_bs보다 작으면 지연한다. 즉, 디코드가 아직 최대 배치에 도달하지 않았으면 프리필을 미룬다.
if prefillable_status == "all":
max_running_requests = kwargs.get("max_running_requests", 0)
if (max_running_requests - global_running_batch.max().item()
< global_max_prefill_bs.max().item()):
if self.skip_first_delayer:
self.skip_first_delayer = False
pass # 첫 번째는 건너뜀
else:
next_state = prev_state or _State()
next_state = next_state.bump_delayed_count()
return _NegotiateOutput(
next_state=next_state, output_allow=False,
output_reason="delay", **debug_info,
)
"mixed" (일부만 프리필 가능): 토큰 워터마크를 초과하면 강제 허용하고, 그렇지 않으면 max_delay_passes - 1번까지 지연 후 timeout으로 허용한다.
elif prefillable_status == "mixed":
if global_exists_token_watermark_force_allow:
return _NegotiateOutput(next_state=None, output_allow=True,
output_reason="token_watermark", ...)
prev_delayed_count = prev_state.delayed_count if prev_state else 0
if prev_delayed_count < self._max_delay_passes - 1:
next_state = (prev_state or _State()).bump_delayed_count()
return _NegotiateOutput(next_state=next_state, output_allow=False,
output_reason="delay", ...)
else:
return _NegotiateOutput(next_state=None, output_allow=True,
output_reason="wait_timeout", ...)
"none" (아무도 프리필 불가): 판단 자체가 불필요하므로 무조건 허용한다.
5. SinglePassExecutor: 패스당 한 번만 호출
PrefillDelayerSinglePassExecutor는 한 스케줄링 패스에서 negotiate가 한 번만 호출되도록 보장한다.
class PrefillDelayerSinglePassExecutor:
def negotiate_should_allow_prefill(self, local_prefillable, **kwargs) -> bool:
if not self._called:
self._result = self._prefill_delayer._negotiate_should_allow_prefill(
local_prefillable=local_prefillable,
token_usage=self._token_usage, **kwargs,
)
return self._result.output_allow
def finalize(self, *, actual_prefill: bool):
if not self._called:
self.negotiate_should_allow_prefill(local_prefillable=False)
_record_single_pass_result(
actual_execution=actual_prefill, output=self._result,
metrics_collector=self._prefill_delayer._metrics_collector,
)
finalize는 실제 프리필 실행 여부와 함께 메트릭을 기록한다. 호출되지 않은 경우 local_prefillable=False로 자동 보고하여 all_gather 동기화가 깨지지 않도록 한다.
6. 메트릭 관측
디버그 로그와 메트릭 컬렉터를 통해 지연 패스 수, 대기 시간, 최종 결정을 상세하게 추적할 수 있다.
def _record_single_pass_result(actual_execution, output, metrics_collector):
if metrics_collector is not None:
if (s := output.next_state) is not None:
wait_seconds = time.perf_counter() - s.start_time
forward_passes = s.delayed_count
else:
wait_seconds = forward_passes = 0
metrics_collector.observe_prefill_delayer_outcome(
forward_passes=forward_passes, wait_seconds=wait_seconds,
input_estimation=output.input_estimation,
output_allow=output.output_allow,
output_reason=output.output_reason,
actual_execution=actual_execution,
)
왜 이 설계인가
프리필-디코드 간섭 문제: 프리필 요청은 대량의 토큰을 한 번에 처리해야 하므로, 디코드 배치와 합쳐지면 디코드 레이턴시가 급증한다. PrefillDelayer는 디코드 배치가 충분히 커질 때까지 프리필을 지연시켜, 디코드 처리량(throughput)을 극대화한다.
"mixed" 상태에서의 타임아웃: 일부 워커만 프리필 가능한 상황에서 무한 대기하면 starvation이 발생한다. max_delay_passes로 상한을 두어, 일정 패스 후에는 타임아웃으로 프리필을 허용한다.
토큰 워터마크 강제 허용: KV cache 사용률이 low watermark 아래로 떨어지면, 디코드 처리량보다 GPU 활용률이 더 중요해진다. 이 경우 즉시 프리필을 허용하여 파이프라인이 비는 것을 방지한다.
불변 상태와 순수 함수: _State의 frozen 패턴과 _negotiate_should_allow_prefill_pure의 (거의) 순수 함수 설계는 상태 전이를 예측 가능하게 만들고, 테스트와 디버깅을 용이하게 한다.
첫 번째 지연 건너뛰기: skip_first_delayer 플래그는 서버 시작 직후 첫 merge_batch에서 디코드가 아직 최대 배치에 도달하지 못한 상황을 예외 처리한다.
관련 포스트
- [SGLang] Pipeline Parallelism 스케줄러: PP 믹스인 설계 (별도 포스트)
- [SGLang] DP Attention 믹스인 분석 (별도 포스트)
참고
- 소스 코드:
python/sglang/srt/managers/prefill_delayer.py - SGLang GitHub: https://github.com/sgl-project/sglang
- Sarathi-Serve 논문: Agrawal et al., "Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve", OSDI 2024 (chunked prefill 관련)
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Data Parallel Attention 스케줄러: DP Attention 믹스인
- 현재글 : [SGLang] Prefill Delayer: 전략적 프리필 지연으로 디코드 처리량 극대화
- 다음글 [SGLang] RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심
댓글