본문으로 건너뛰기

[sglang] SGLang의 Breakable CUDA Graph 최적화: 배치 사이즈 제한 극복하기

PR 링크: sgl-project/sglang#24662 상태: Merged | 변경: +0 / -0

들어가며

LLM 추론 엔진에서 CUDA Graph는 정적인 연산 그래프를 미리 기록하여 커널 실행 오버헤드를 줄이는 강력한 도구입니다. 하지만 기존 SGLang의 BreakableCudaGraphRunner는 배치 사이즈(bs)가 1인 경우에만 작동하도록 설계되어 있어, 멀티 배치 환경에서의 성능 이점을 충분히 활용하지 못하는 제약이 있었습니다. 이번 PR은 이 제한을 깨고, bs > 1인 상황에서도 CUDA Graph를 활용할 수 있도록 아키텍처를 재설계했습니다.

코드 분석

1. 모델 레이어 분리 및 캡처 범위 조정

기존에는 모델 전체의 forward를 캡처하려 했으나, 이는 logits_processor와 같이 배치 사이즈에 의존적인 연산까지 포함하게 되어 유연성을 떨어뜨렸습니다. 이제는 Transformer 레이어 스택만을 분리하여 캡처합니다.

# Before
# (기존에는 모델 전체를 캡처하여 bs=1에 고정됨)

# After
language_model = getattr(model_runner.model, "language_model", model_runner.model)
self.layer_model = (
    language_model.model
    if hasattr(language_model, "model") and hasattr(language_model.model, "layers")
    else language_model
)

2. Monkey-patching을 통한 동적 Replay

replay 시점에 layer_model.forward를 캡처된 그래프를 실행하는 클로저로 임시 교체(Monkey-patching)합니다. 이를 통해 외부의 logits_processor는 실제 멀티 배치 데이터를 처리하고, 내부의 연산은 최적화된 그래프를 재사용하게 됩니다.

# After
def replay_layer_forward(*args, **layer_kwargs):
    captured_graph.replay()
    return captured_hidden

original_layer_forward = self.layer_model.forward
self.layer_model.forward = replay_layer_forward

3. @torch.no_grad 적용

layer_model.forward를 직접 호출하면서 기존 모델 클래스에 있던 @torch.no_grad 데코레이터가 누락되는 문제를 해결했습니다. 이는 torch.compile 기반의 MoE 커널 실행 시 그래디언트 추적 오류를 방지합니다.

왜 이게 좋은가

이 최적화의 핵심은 '그래프 캡처의 범위(Scope)를 최적화'한 것입니다.

  1. 범위의 분리: 배치 사이즈에 민감한 로직(logits, pooler)을 그래프 밖으로 빼냄으로써, 그래프는 배치 사이즈와 무관한 '토큰 단위 연산'에만 집중할 수 있게 되었습니다. 이는 다양한 배치 사이즈 요청이 들어오는 실제 서빙 환경에서 매우 중요합니다.
  2. 유연성: can_run 메서드에서 forward_batch.batch_size > 1 체크를 제거하여, 이제 더 넓은 범위의 추론 요청을 그래프 가속 경로로 처리할 수 있습니다.
  3. 교훈: 복잡한 모델을 최적화할 때, 전체를 하나의 그래프로 묶으려 하기보다 '정적인 연산'과 '동적인 연산'을 분리하여 부분적으로 그래프를 적용하는 것이 훨씬 강력한 성능과 유연성을 제공합니다.

리뷰어 피드백 반영

리뷰어 merrymercygetattr을 통한 모듈 탐색이 문자열 기반의 매칭이라 다소 'hacky'하다고 지적했습니다. 이는 향후 모델 구조 변경 시 잠재적인 버그 포인트가 될 수 있으므로, 매칭 실패 시 명확한 경고(warning)를 출력하도록 보완하는 것이 권장됩니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글