[sglang] SGLang의 Breakable CUDA Graph 최적화: 배치 사이즈 제한 극복하기
PR 링크: sgl-project/sglang#24662 상태: Merged | 변경: +0 / -0
들어가며
LLM 추론 엔진에서 CUDA Graph는 정적인 연산 그래프를 미리 기록하여 커널 실행 오버헤드를 줄이는 강력한 도구입니다. 하지만 기존 SGLang의 BreakableCudaGraphRunner는 배치 사이즈(bs)가 1인 경우에만 작동하도록 설계되어 있어, 멀티 배치 환경에서의 성능 이점을 충분히 활용하지 못하는 제약이 있었습니다. 이번 PR은 이 제한을 깨고, bs > 1인 상황에서도 CUDA Graph를 활용할 수 있도록 아키텍처를 재설계했습니다.
코드 분석
1. 모델 레이어 분리 및 캡처 범위 조정
기존에는 모델 전체의 forward를 캡처하려 했으나, 이는 logits_processor와 같이 배치 사이즈에 의존적인 연산까지 포함하게 되어 유연성을 떨어뜨렸습니다. 이제는 Transformer 레이어 스택만을 분리하여 캡처합니다.
# Before
# (기존에는 모델 전체를 캡처하여 bs=1에 고정됨)
# After
language_model = getattr(model_runner.model, "language_model", model_runner.model)
self.layer_model = (
language_model.model
if hasattr(language_model, "model") and hasattr(language_model.model, "layers")
else language_model
)
2. Monkey-patching을 통한 동적 Replay
replay 시점에 layer_model.forward를 캡처된 그래프를 실행하는 클로저로 임시 교체(Monkey-patching)합니다. 이를 통해 외부의 logits_processor는 실제 멀티 배치 데이터를 처리하고, 내부의 연산은 최적화된 그래프를 재사용하게 됩니다.
# After
def replay_layer_forward(*args, **layer_kwargs):
captured_graph.replay()
return captured_hidden
original_layer_forward = self.layer_model.forward
self.layer_model.forward = replay_layer_forward
3. @torch.no_grad 적용
layer_model.forward를 직접 호출하면서 기존 모델 클래스에 있던 @torch.no_grad 데코레이터가 누락되는 문제를 해결했습니다. 이는 torch.compile 기반의 MoE 커널 실행 시 그래디언트 추적 오류를 방지합니다.
왜 이게 좋은가
이 최적화의 핵심은 '그래프 캡처의 범위(Scope)를 최적화'한 것입니다.
- 범위의 분리: 배치 사이즈에 민감한 로직(logits, pooler)을 그래프 밖으로 빼냄으로써, 그래프는 배치 사이즈와 무관한 '토큰 단위 연산'에만 집중할 수 있게 되었습니다. 이는 다양한 배치 사이즈 요청이 들어오는 실제 서빙 환경에서 매우 중요합니다.
- 유연성:
can_run메서드에서forward_batch.batch_size > 1체크를 제거하여, 이제 더 넓은 범위의 추론 요청을 그래프 가속 경로로 처리할 수 있습니다. - 교훈: 복잡한 모델을 최적화할 때, 전체를 하나의 그래프로 묶으려 하기보다 '정적인 연산'과 '동적인 연산'을 분리하여 부분적으로 그래프를 적용하는 것이 훨씬 강력한 성능과 유연성을 제공합니다.
리뷰어 피드백 반영
리뷰어 merrymercy는 getattr을 통한 모듈 탐색이 문자열 기반의 매칭이라 다소 'hacky'하다고 지적했습니다. 이는 향후 모델 구조 변경 시 잠재적인 버그 포인트가 될 수 있으므로, 매칭 실패 시 명확한 경고(warning)를 출력하도록 보완하는 것이 권장됩니다.
References
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang P/D Disaggregation: Decode-Side Radix Cache 도입으로 LLM 추론 성능 극대화
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [sglang] SGLang Whisper 모델의 CUDA Graph 도입 및 성능 최적화 분석
- [sglang] SGLang, 레이어별 오프로딩 기본값 설정을 통한 인코더/VAE 성능 최적화
- [sglang] SGLang의 MLA KV 캐시 쓰기 최적화: TMA Bulk-Store 도입
PR Analysis 의 다른글
- 이전글 [flashinfer] FlashInfer, 동적 토큰 페이지 커널 도입으로 TRTLLM-GEN GQA 성능 최적화
- 현재글 : [sglang] SGLang의 Breakable CUDA Graph 최적화: 배치 사이즈 제한 극복하기
- 다음글 [vllm] vLLM, DeepSeek-V4 K 캐시 커널 최적화: CuteDSL 도입으로 성능 향상
댓글