[sglang] SGLang 스케줄러 최적화: input_ids H2D 지연 처리 및 FutureMap 통합
PR 링크: sgl-project/sglang#25945 상태: Merged | 변경: +167 / -127
들어가며
LLM 추론 엔진에서 스케줄링 단계와 모델 실행(Forward) 단계 사이의 데이터 전송은 성능의 병목이 될 수 있습니다. 특히 input_ids를 스케줄러에서 GPU로 미리 복사(H2D)하는 방식은 스트림 간의 동기화 오버헤드를 유발합니다. 이번 SGLang의 PR은 input_ids의 H2D 작업을 schedule_stream에서 forward_stream으로 지연(Defer)시키고, 모든 입력 처리 경로를 FutureMap이라는 중앙 집중식 릴레이 시스템으로 통합하여 성능을 최적화했습니다.
코드 분석
1. python/sglang/srt/managers/schedule_batch.py: H2D 지연 처리
기존에는 prepare_for_extend 단계에서 즉시 GPU 텐서를 생성했으나, 이제는 CPU pinned memory에 데이터를 유지하고 Forward 단계에서 처리합니다.
Before:
input_ids_tensor = flatten_arrays_to_int64_tensor(input_ids, self.device, _pin)
self.input_ids = input_ids_tensor
After:
pinned_input_ids = flatten_arrays_to_pinned_cpu(input_ids, _pin)
self.input_ids = None
self.prefill_input_ids_cpu = pinned_input_ids
이를 통해 스케줄러는 GPU 메모리 할당 및 복사로부터 자유로워지며, Forward entry 시점에 resolve_forward_inputs를 통해 필요한 시점에만 H2D가 수행됩니다.
2. python/sglang/srt/managers/overlap_utils.py: FutureMap 통합
모든 입력(Prefill, Decode, Mixed)은 이제 resolve_forward_inputs를 거쳐 구체화됩니다. 특히 Mixed batch의 경우 torch.cat을 사용하여 CPU에서 가져온 Prefill 데이터와 FutureMap에 저장된 Decode 토큰을 결합합니다.
def resolve_forward_inputs(batch: ScheduleBatch, future_map: FutureMap) -> None:
if batch.prefill_input_ids_cpu is not None:
prefill_gpu = batch.prefill_input_ids_cpu.to(batch.device, non_blocking=True)
if batch.mix_running_indices is not None:
decode_gpu = future_map.output_tokens_buf[batch.mix_running_indices]
batch.input_ids = torch.cat([prefill_gpu, decode_gpu])
else:
batch.input_ids = prefill_gpu
batch.prefill_input_ids_cpu = None
왜 이게 좋은가
- 스트림 오버헤드 감소: 스케줄링 스트림에서 GPU 복사를 제거함으로써, 스케줄러가 다음 배치를 준비하는 동안 Forward 스트림이 이전 배치를 실행하는 파이프라인 병렬성이 극대화됩니다.
- 일관된 데이터 경로:
FutureMap을 통해 모든 입력 유형(prefill, decode, mixed)을 단일 경로로 처리함으로써 코드 복잡도를 줄이고,input_ids=None으로 인한 런타임 크래시를 방지했습니다. - 유연성:
spec_v1과 같은 기존 레거시 모드와의 호환성을 유지하면서도,spec_v2와 같은 최신 기능에 필요한 동적 데이터 릴레이를 표준화했습니다.
이 최적화는 비동기 데이터 전송(Asynchronous H2D)을 활용하여 GPU 연산 유닛의 유휴 시간을 줄이는 전형적인 고성능 시스템 설계 패턴을 따르고 있습니다. 특히 non_blocking=True를 활용한 H2D는 스트림 간의 간섭을 최소화하는 핵심 전략입니다.
리뷰어 피드백 반영
리뷰 과정에서 spec_v1과 spec_v2 간의 데이터 처리 차이를 명확히 구분하고, _DEBUG_ASSERT를 통해 CI 환경에서 데이터 무결성을 검증하는 로직이 추가되었습니다. 이는 복잡한 비동기 로직에서 발생할 수 있는 '데이터 누락' 버그를 방지하는 데 큰 역할을 합니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.Tensor.to.html
- https://pytorch.org/docs/stable/generated/torch.cat.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang의 MLA KV 캐시 쓰기 최적화: TMA Bulk-Store 도입
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- [sglang] SGLang의 FA3 디코드 최적화: get_scheduler_metadata 도입
PR Analysis 의 다른글
- 이전글 [vllm] [vLLM] MiniMax-M2 MoE Gate 최적화: Fused FP32 Kernel로 서빙 성능 32% 향상시키기
- 현재글 : [sglang] SGLang 스케줄러 최적화: input_ids H2D 지연 처리 및 FutureMap 통합
- 다음글 [cpython] tarfile 스트리밍 모드(r|*) 성능 개선: 파이썬 압축 파일 처리의 숨겨진 병목 제거
댓글