[sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
PR 링크: sgl-project/sglang#25328 상태: Merged | 변경: +0 / -0
들어가며
최근 AI 이미지 생성 모델의 발전 속도는 놀랍습니다. 특히 Stable Diffusion과 같은 확산 모델(Diffusion Model)은 고품질 이미지를 생성하는 데 핵심적인 역할을 하고 있습니다. 이러한 모델을 효율적으로 서빙하는 것은 실제 서비스 환경에서 매우 중요하며, 이를 위해 다양한 최적화 기법이 연구되고 적용됩니다.
이번 글에서는 sglang 레포지토리의 한 Pull Request(PR)에서 이루어진 최적화에 대해 자세히 살펴보겠습니다. 이 PR은 sglang의 diffusion 모델 서빙 시 Cache-DiT와 torch.compile의 적용 순서를 변경하여, 첫 번째 실제 요청의 지연 시간을 평균 43.77% 단축하는 놀라운 성능 개선을 이루었습니다. 왜 이러한 변경이 가능했는지, 그리고 이 최적화가 우리에게 주는 교훈은 무엇인지 코드 변경사항과 함께 분석해보겠습니다.
문제 정의: 잘못된 컴파일 순서로 인한 성능 저하
이 PR이 해결하고자 하는 핵심 문제는 Cache-DiT와 torch.compile의 적용 순서가 잘못되어 발생했던 성능 저하입니다. Diffusion 모델은 여러 단계의 노이즈 제거(denoising) 과정을 거치는데, Cache-DiT는 이 과정에서 중간 결과(캐시)를 활용하여 계산량을 줄이는 기법입니다. torch.compile은 PyTorch 코드를 최적화된 코드로 컴파일하여 실행 속도를 높이는 기능입니다.
기존 코드에서는 torch.compile이 Cache-DiT가 적용되기 전에 먼저 실행될 수 있었습니다. --warmup 옵션을 사용할 경우, 이 과정에서 torch.compile은 Cache-DiT가 적용되지 않은 순수한 트랜스포머 경로를 컴파일하게 됩니다. 하지만 실제 요청이 들어오면 Cache-DiT가 적용된 경로를 사용하게 되는데, 이 경로는 워밍업 단계에서 미리 컴파일되지 않았기 때문에 첫 번째 실제 요청에서 상당한 지연이 발생했습니다. 즉, 워밍업이 실제 사용될 경로를 제대로 준비하지 못했던 것입니다.
코드 분석: 변경된 DenoisingStage 로직
핵심 변경은 python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py 파일에서 이루어졌습니다. 특히 DenoisingStage 클래스의 __init__ 메서드와 _prepare_denoising_loop 메서드, 그리고 새로운 헬퍼 메서드 _maybe_enable_cache_dit_and_torch_compile이 중요합니다.
1. DenoisingStage.__init__ 메서드 변경
__init__ 메서드에서는 torch.compile이 이미 적용되었는지 추적하기 위한 _torch_compiled_module_ids 세트가 추가되었습니다. 이는 동일한 모듈에 대해 torch.compile이 중복으로 호출되는 것을 방지하기 위함입니다.
Before:
# misc
self.profiler = None
# cache-dit state (for delayed mounting and idempotent control)
self._cache_dit_enabled = False
self._cached_num_steps = None
self._is_warmed_up = False
self._extra_func_kwarg_names_cache: dict[int, tuple[bool, frozenset[str]]] = {}
After:
# cache-dit state (for delayed mounting and idempotent control)
self._cache_dit_enabled = False
self._cached_num_steps = None
self._torch_compiled_module_ids: set[int] = set()
# misc
self.profiler = None
self._is_warmed_up = False
self._extra_func_kwarg_names_cache: dict[int, tuple[bool, frozenset[str]]] = {}
2. _maybe_enable_torch_compile 메서드 수정
_maybe_enable_torch_compile 메서드는 torch.compile을 적용하는 로직을 담당합니다. 이 메서드 시작 부분에 Cache-DiT가 활성화되지 않았거나, 이미 컴파일된 모듈인 경우 torch.compile 적용을 건너뛰는 조건이 추가되었습니다. 이는 Cache-DiT가 먼저 마운트된 후에만 torch.compile이 실행되도록 보장하는 중요한 변경입니다.
Before:
if envs.SGLANG_CACHE_DIT_ENABLED and not self._cache_dit_enabled:
logger.debug("Deferring torch.compile until cache-dit is enabled")
return
module_id = id(module)
if module_id in self._torch_compiled_module_ids:
return
After: (이전 코드 블록은 사실상 After 코드의 일부로, Before에는 해당 로직이 없었음을 의미합니다. 즉, 이 로직이 새로 추가된 것입니다.)
# ... (기존 로직) ...
if envs.SGLANG_CACHE_DIT_ENABLED and not self._cache_dit_enabled:
logger.debug("Deferring torch.compile until cache-dit is enabled")
return
module_id = id(module)
if module_id in self._torch_compiled_module_ids:
return
compile_kwargs: dict[str, Any] = {"fullgraph": False, "dynamic": None}
# ... (이하 생략) ...
module.compile(**compile_kwargs)
self._torch_compiled_module_ids.add(module_id)
3. 새로운 헬퍼 메서드: _maybe_enable_cache_dit_and_torch_compile
리뷰어의 제안에 따라 (@mickqian의 코멘트 참조), _maybe_enable_cache_dit와 _maybe_enable_torch_compile 로직이 _maybe_enable_cache_dit_and_torch_compile라는 단일 메서드로 통합되었습니다. 이 메서드는 Cache-DiT 활성화와 torch.compile 적용을 정확한 순서로 수행합니다.
After (새로 추가):
def _maybe_enable_cache_dit_and_torch_compile(
self,
num_inference_steps: int | tuple[int, int],
batch: Req
) -> None:
"""Apply request-dependent transformer acceleration in trace-safe order."""
self._maybe_enable_cache_dit(num_inference_steps, batch)
for transformer in filter(None, [self.transformer, self.transformer_2]):
self._maybe_enable_torch_compile(transformer)
4. _prepare_denoising_loop 메서드 변경
가장 중요한 변경은 _prepare_denoising_loop 메서드에서 발생합니다. 트랜스포머가 처음 로드될 때, Cache-DiT를 먼저 활성화하고 그 후에 torch.compile을 호출하도록 로직이 수정되었습니다. 또한, 트랜스포머가 이미 로드된 경우에도 _maybe_enable_cache_dit_and_torch_compile 메서드를 호출하여 Cache-DiT와 torch.compile이 올바른 순서로 적용되도록 보장합니다.
Before:
if not server_args.model_loaded["transformer"]:
# FIXME: reuse more code
loader = TransformerLoader()
self.transformer = loader.load(
server_args.model_paths["transformer"], server_args, "transformer"
)
# enable cache-dit before torch.compile (delayed mounting)
self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)
self._maybe_enable_torch_compile(self.transformer)
if pipeline:
pipeline.add_module("transformer", self.transformer)
server_args.model_loaded["transformer"] = True
else:
self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)
After:
transformer_was_loaded = server_args.model_loaded["transformer"]
if not transformer_was_loaded:
# FIXME: reuse more code
loader = TransformerLoader()
self.transformer = loader.load(
server_args.model_paths["transformer"], server_args, "transformer"
)
self._maybe_enable_cache_dit_and_torch_compile(
cache_dit_num_inference_steps, batch
)
if not transformer_was_loaded:
if pipeline:
pipeline.add_module("transformer", self.transformer)
server_args.model_loaded["transformer"] = True
이 변경을 통해 Cache-DiT가 먼저 적용되고, 그 후에 torch.compile이 실행되도록 순서가 명확하게 보장됩니다. 이는 워밍업 단계에서 Cache-DiT가 적용된 경로가 컴파일되도록 하여, 실제 요청 시의 지연 시간을 크게 줄이는 효과를 가져옵니다.
5. _maybe_enable_cache_dit 메서드 수정
_maybe_enable_cache_dit 메서드 내부의 조건문도 수정되었습니다. 기존에는 batch.is_warmup인 경우 Cache-DiT를 비활성화했지만, 이제는 server_args.enable_torch_compile이 활성화된 경우 워밍업 중에도 Cache-DiT를 활성화할 수 있도록 변경되었습니다. 이는 torch.compile 워밍업이 Cache-DiT를 포함한 전체 경로를 준비하도록 하기 위함입니다.
Before:
# check if cache-dit is enabled in config
if not envs.SGLANG_CACHE_DIT_ENABLED or batch.is_warmup:
return
After:
# Keep cache-dit disabled for ordinary warmup, but allow torch.compile
# warmup to mount cache-dit before Dynamo traces the transformer.
if not envs.SGLANG_CACHE_DIT_ENABLED:
return
if batch.is_warmup and not self.server_args.enable_torch_compile:
return
왜 이게 좋은가: 성능 개선과 일반화된 교훈
성능 수치
이 PR의 가장 큰 성과는 첫 번째 실제 요청의 DenoisingStage 시간 단축입니다. 테스트 결과에 따르면, 평균적으로 5.2370초에서 2.9449초로 43.77%의 지연 시간 감소를 보였습니다. 이는 약 1.77배의 속도 향상(Speedup: 77.83%)에 해당합니다.
| Run | Baseline | Patched | Reduction |
|---|---|---|---|
| r1 | 5.2040s | 2.9329s | 43.64% |
| r2 | 5.2066s | 2.9935s | 42.51% |
| r3 | 5.2888s | 2.9220s | 44.75% |
| r4 | 5.2033s | 2.9528s | 43.25% |
| r5 | 5.2824s | 2.9234s | 44.66% |
| Mean | 5.2370s | 2.9449s | 43.77% |
클라이언트 전체 지연 시간에서도 평균 2.89%의 감소가 있었지만, DenoisingStage 자체의 성능 개선이 훨씬 두드러졌습니다. 이는 DenoisingStage가 전체 요청 처리 시간에서 차지하는 비중이 크다는 것을 의미하며, 해당 부분의 최적화가 전체 성능에 미치는 영향이 지대함을 보여줍니다.
일반화된 교훈
- 최적화 기법의 적용 순서의 중요성:
torch.compile과 같은 성능 최적화 도구는 적용되는 코드의 상태에 따라 성능이 달라질 수 있습니다. 특히, 동적으로 변경되거나 캐싱 로직이 포함된 경우, 최적화 도구가 적용되기 전에 해당 로직이 완전히 준비되고 활성화되어야 최상의 성능을 기대할 수 있습니다. 이 PR은Cache-DiT라는 캐싱 로직이torch.compile보다 먼저 적용되어야 함을 명확히 보여줍니다. - 워밍업(Warmup) 전략의 재검토: 워밍업은 실제 요청 경로를 준비하는 데 사용되지만, 잘못 설계되면 오히려 성능 저하의 원인이 될 수 있습니다. 워밍업 시 컴파일되는 경로와 실제 요청 시 사용될 경로가 일치하도록 주의해야 합니다. 이 PR은
--warmup옵션이Cache-DiT가 적용되지 않은 경로를 컴파일하는 문제를 해결했습니다. - 코드 모듈화 및 통합: 리뷰어의 피드백(
@mickqian)처럼, 관련 로직을 하나의 메서드로 통합하는 것은 코드의 가독성과 유지보수성을 높이는 좋은 방법입니다._maybe_enable_cache_dit_and_torch_compile메서드는Cache-DiT활성화와torch.compile적용이라는 두 가지 관련 작업을 원자적으로, 그리고 올바른 순서로 처리하도록 하여 코드의 안정성을 높였습니다. - 점진적 컴파일(Lazy Compilation)의 이해:
torch.compile은 기본적으로 동적 그래프(dynamic graph)를 생성하거나,torch.compile(..., dynamic=True)옵션을 통해 동적 연산을 지원합니다. 하지만 최적화된 코드가 실제 실행될 때까지 컴파일이 지연되는 특성 때문에, 코드 실행 흐름과 컴파일 시점을 정확히 이해하는 것이 중요합니다. 이 PR은torch.compile이Cache-DiT가 마운트된 후에 실행되도록 하여,Cache-DiT의 이점을 살린 컴파일을 유도했습니다.
결론
sglang 레포지토리의 이 PR은 Cache-DiT와 torch.compile의 적용 순서를 최적화함으로써 diffusion 모델의 첫 번째 요청 지연 시간을 획기적으로 단축했습니다. 이는 단순히 코드 몇 줄을 변경한 것이 아니라, 최적화 기법들의 상호작용과 워밍업 전략의 중요성을 깊이 이해한 결과입니다. 이러한 분석은 복잡한 AI 모델 서빙 시스템에서 성능을 극대화하기 위한 중요한 통찰을 제공합니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/sgl-project/sglang/blob/main/python/sglang/utils/env.py
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] sglang 성능 최적화: torch.compile 퓨전 복원을 통한 TopK 후처리 개선
- [sglang] LTX2.3 HQ Denoising 성능 최적화: Attention Skip을 활용한 효율적인 모델 호출
- [transformers] Hugging Face Transformers: MoE 및 FP8 커널 최적화를 통한 성능 향상
- [cpython] Python subprocess.communicate() 타임아웃 성능 개선: 느린 자식 프로세스 응답 방식 변경
- [cpython] Python `subprocess` 테스트 최적화: `communicate()` 타임아웃 테스트 속도 향상
PR Analysis 의 다른글
- 이전글 [triton] Triton 커널 최적화: Mask Sorting을 통한 Reduction 연산 가속화
- 현재글 : [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- 다음글 [sglang] SGLang의 MLA KV 캐시 쓰기 최적화: TMA Bulk-Store 도입
댓글