본문으로 건너뛰기

[sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서

PR 링크: sgl-project/sglang#25328 상태: Merged | 변경: +0 / -0

들어가며

최근 AI 이미지 생성 모델의 발전 속도는 놀랍습니다. 특히 Stable Diffusion과 같은 확산 모델(Diffusion Model)은 고품질 이미지를 생성하는 데 핵심적인 역할을 하고 있습니다. 이러한 모델을 효율적으로 서빙하는 것은 실제 서비스 환경에서 매우 중요하며, 이를 위해 다양한 최적화 기법이 연구되고 적용됩니다.

이번 글에서는 sglang 레포지토리의 한 Pull Request(PR)에서 이루어진 최적화에 대해 자세히 살펴보겠습니다. 이 PR은 sglang의 diffusion 모델 서빙 시 Cache-DiTtorch.compile의 적용 순서를 변경하여, 첫 번째 실제 요청의 지연 시간을 평균 43.77% 단축하는 놀라운 성능 개선을 이루었습니다. 왜 이러한 변경이 가능했는지, 그리고 이 최적화가 우리에게 주는 교훈은 무엇인지 코드 변경사항과 함께 분석해보겠습니다.

문제 정의: 잘못된 컴파일 순서로 인한 성능 저하

이 PR이 해결하고자 하는 핵심 문제는 Cache-DiTtorch.compile의 적용 순서가 잘못되어 발생했던 성능 저하입니다. Diffusion 모델은 여러 단계의 노이즈 제거(denoising) 과정을 거치는데, Cache-DiT는 이 과정에서 중간 결과(캐시)를 활용하여 계산량을 줄이는 기법입니다. torch.compile은 PyTorch 코드를 최적화된 코드로 컴파일하여 실행 속도를 높이는 기능입니다.

기존 코드에서는 torch.compileCache-DiT가 적용되기 전에 먼저 실행될 수 있었습니다. --warmup 옵션을 사용할 경우, 이 과정에서 torch.compileCache-DiT가 적용되지 않은 순수한 트랜스포머 경로를 컴파일하게 됩니다. 하지만 실제 요청이 들어오면 Cache-DiT가 적용된 경로를 사용하게 되는데, 이 경로는 워밍업 단계에서 미리 컴파일되지 않았기 때문에 첫 번째 실제 요청에서 상당한 지연이 발생했습니다. 즉, 워밍업이 실제 사용될 경로를 제대로 준비하지 못했던 것입니다.

코드 분석: 변경된 DenoisingStage 로직

핵심 변경은 python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py 파일에서 이루어졌습니다. 특히 DenoisingStage 클래스의 __init__ 메서드와 _prepare_denoising_loop 메서드, 그리고 새로운 헬퍼 메서드 _maybe_enable_cache_dit_and_torch_compile이 중요합니다.

1. DenoisingStage.__init__ 메서드 변경

__init__ 메서드에서는 torch.compile이 이미 적용되었는지 추적하기 위한 _torch_compiled_module_ids 세트가 추가되었습니다. 이는 동일한 모듈에 대해 torch.compile이 중복으로 호출되는 것을 방지하기 위함입니다.

Before:

        # misc
        self.profiler = None
        # cache-dit state (for delayed mounting and idempotent control)
        self._cache_dit_enabled = False
        self._cached_num_steps = None
        self._is_warmed_up = False
        self._extra_func_kwarg_names_cache: dict[int, tuple[bool, frozenset[str]]] = {}

After:

        # cache-dit state (for delayed mounting and idempotent control)
        self._cache_dit_enabled = False
        self._cached_num_steps = None
        self._torch_compiled_module_ids: set[int] = set()

        # misc
        self.profiler = None
        self._is_warmed_up = False
        self._extra_func_kwarg_names_cache: dict[int, tuple[bool, frozenset[str]]] = {}

2. _maybe_enable_torch_compile 메서드 수정

_maybe_enable_torch_compile 메서드는 torch.compile을 적용하는 로직을 담당합니다. 이 메서드 시작 부분에 Cache-DiT가 활성화되지 않았거나, 이미 컴파일된 모듈인 경우 torch.compile 적용을 건너뛰는 조건이 추가되었습니다. 이는 Cache-DiT가 먼저 마운트된 후에만 torch.compile이 실행되도록 보장하는 중요한 변경입니다.

Before:

        if envs.SGLANG_CACHE_DIT_ENABLED and not self._cache_dit_enabled:
            logger.debug("Deferring torch.compile until cache-dit is enabled")
            return
        module_id = id(module)
        if module_id in self._torch_compiled_module_ids:
            return

After: (이전 코드 블록은 사실상 After 코드의 일부로, Before에는 해당 로직이 없었음을 의미합니다. 즉, 이 로직이 새로 추가된 것입니다.)

        # ... (기존 로직) ...
        if envs.SGLANG_CACHE_DIT_ENABLED and not self._cache_dit_enabled:
            logger.debug("Deferring torch.compile until cache-dit is enabled")
            return
        module_id = id(module)
        if module_id in self._torch_compiled_module_ids:
            return

        compile_kwargs: dict[str, Any] = {"fullgraph": False, "dynamic": None}
        # ... (이하 생략) ...
        module.compile(**compile_kwargs)
        self._torch_compiled_module_ids.add(module_id)

3. 새로운 헬퍼 메서드: _maybe_enable_cache_dit_and_torch_compile

리뷰어의 제안에 따라 (@mickqian의 코멘트 참조), _maybe_enable_cache_dit_maybe_enable_torch_compile 로직이 _maybe_enable_cache_dit_and_torch_compile라는 단일 메서드로 통합되었습니다. 이 메서드는 Cache-DiT 활성화와 torch.compile 적용을 정확한 순서로 수행합니다.

After (새로 추가):

def _maybe_enable_cache_dit_and_torch_compile(
    self,
    num_inference_steps: int | tuple[int, int],
    batch: Req
) -> None:
    """Apply request-dependent transformer acceleration in trace-safe order."""
    self._maybe_enable_cache_dit(num_inference_steps, batch)
    for transformer in filter(None, [self.transformer, self.transformer_2]):
        self._maybe_enable_torch_compile(transformer)

4. _prepare_denoising_loop 메서드 변경

가장 중요한 변경은 _prepare_denoising_loop 메서드에서 발생합니다. 트랜스포머가 처음 로드될 때, Cache-DiT를 먼저 활성화하고 그 후에 torch.compile을 호출하도록 로직이 수정되었습니다. 또한, 트랜스포머가 이미 로드된 경우에도 _maybe_enable_cache_dit_and_torch_compile 메서드를 호출하여 Cache-DiTtorch.compile이 올바른 순서로 적용되도록 보장합니다.

Before:

        if not server_args.model_loaded["transformer"]:
            # FIXME: reuse more code
            loader = TransformerLoader()
            self.transformer = loader.load(
                server_args.model_paths["transformer"], server_args, "transformer"
            )
            # enable cache-dit before torch.compile (delayed mounting)
            self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)
            self._maybe_enable_torch_compile(self.transformer)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            server_args.model_loaded["transformer"] = True
        else:
            self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)

After:

        transformer_was_loaded = server_args.model_loaded["transformer"]
        if not transformer_was_loaded:
            # FIXME: reuse more code
            loader = TransformerLoader()
            self.transformer = loader.load(
                server_args.model_paths["transformer"], server_args, "transformer"
            )

        self._maybe_enable_cache_dit_and_torch_compile(
            cache_dit_num_inference_steps, batch
        )

        if not transformer_was_loaded:
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            server_args.model_loaded["transformer"] = True

이 변경을 통해 Cache-DiT가 먼저 적용되고, 그 후에 torch.compile이 실행되도록 순서가 명확하게 보장됩니다. 이는 워밍업 단계에서 Cache-DiT가 적용된 경로가 컴파일되도록 하여, 실제 요청 시의 지연 시간을 크게 줄이는 효과를 가져옵니다.

5. _maybe_enable_cache_dit 메서드 수정

_maybe_enable_cache_dit 메서드 내부의 조건문도 수정되었습니다. 기존에는 batch.is_warmup인 경우 Cache-DiT를 비활성화했지만, 이제는 server_args.enable_torch_compile이 활성화된 경우 워밍업 중에도 Cache-DiT를 활성화할 수 있도록 변경되었습니다. 이는 torch.compile 워밍업이 Cache-DiT를 포함한 전체 경로를 준비하도록 하기 위함입니다.

Before:

        # check if cache-dit is enabled in config
        if not envs.SGLANG_CACHE_DIT_ENABLED or batch.is_warmup:
            return

After:

        # Keep cache-dit disabled for ordinary warmup, but allow torch.compile
        # warmup to mount cache-dit before Dynamo traces the transformer.
        if not envs.SGLANG_CACHE_DIT_ENABLED:
            return
        if batch.is_warmup and not self.server_args.enable_torch_compile:
            return

왜 이게 좋은가: 성능 개선과 일반화된 교훈

성능 수치

이 PR의 가장 큰 성과는 첫 번째 실제 요청의 DenoisingStage 시간 단축입니다. 테스트 결과에 따르면, 평균적으로 5.2370초에서 2.9449초로 43.77%의 지연 시간 감소를 보였습니다. 이는 약 1.77배의 속도 향상(Speedup: 77.83%)에 해당합니다.

Run Baseline Patched Reduction
r1 5.2040s 2.9329s 43.64%
r2 5.2066s 2.9935s 42.51%
r3 5.2888s 2.9220s 44.75%
r4 5.2033s 2.9528s 43.25%
r5 5.2824s 2.9234s 44.66%
Mean 5.2370s 2.9449s 43.77%

클라이언트 전체 지연 시간에서도 평균 2.89%의 감소가 있었지만, DenoisingStage 자체의 성능 개선이 훨씬 두드러졌습니다. 이는 DenoisingStage가 전체 요청 처리 시간에서 차지하는 비중이 크다는 것을 의미하며, 해당 부분의 최적화가 전체 성능에 미치는 영향이 지대함을 보여줍니다.

일반화된 교훈

  1. 최적화 기법의 적용 순서의 중요성: torch.compile과 같은 성능 최적화 도구는 적용되는 코드의 상태에 따라 성능이 달라질 수 있습니다. 특히, 동적으로 변경되거나 캐싱 로직이 포함된 경우, 최적화 도구가 적용되기 전에 해당 로직이 완전히 준비되고 활성화되어야 최상의 성능을 기대할 수 있습니다. 이 PR은 Cache-DiT라는 캐싱 로직이 torch.compile보다 먼저 적용되어야 함을 명확히 보여줍니다.
  2. 워밍업(Warmup) 전략의 재검토: 워밍업은 실제 요청 경로를 준비하는 데 사용되지만, 잘못 설계되면 오히려 성능 저하의 원인이 될 수 있습니다. 워밍업 시 컴파일되는 경로와 실제 요청 시 사용될 경로가 일치하도록 주의해야 합니다. 이 PR은 --warmup 옵션이 Cache-DiT가 적용되지 않은 경로를 컴파일하는 문제를 해결했습니다.
  3. 코드 모듈화 및 통합: 리뷰어의 피드백(@mickqian)처럼, 관련 로직을 하나의 메서드로 통합하는 것은 코드의 가독성과 유지보수성을 높이는 좋은 방법입니다. _maybe_enable_cache_dit_and_torch_compile 메서드는 Cache-DiT 활성화와 torch.compile 적용이라는 두 가지 관련 작업을 원자적으로, 그리고 올바른 순서로 처리하도록 하여 코드의 안정성을 높였습니다.
  4. 점진적 컴파일(Lazy Compilation)의 이해: torch.compile은 기본적으로 동적 그래프(dynamic graph)를 생성하거나, torch.compile(..., dynamic=True) 옵션을 통해 동적 연산을 지원합니다. 하지만 최적화된 코드가 실제 실행될 때까지 컴파일이 지연되는 특성 때문에, 코드 실행 흐름과 컴파일 시점을 정확히 이해하는 것이 중요합니다. 이 PR은 torch.compileCache-DiT가 마운트된 후에 실행되도록 하여, Cache-DiT의 이점을 살린 컴파일을 유도했습니다.

결론

sglang 레포지토리의 이 PR은 Cache-DiTtorch.compile의 적용 순서를 최적화함으로써 diffusion 모델의 첫 번째 요청 지연 시간을 획기적으로 단축했습니다. 이는 단순히 코드 몇 줄을 변경한 것이 아니라, 최적화 기법들의 상호작용과 워밍업 전략의 중요성을 깊이 이해한 결과입니다. 이러한 분석은 복잡한 AI 모델 서빙 시스템에서 성능을 극대화하기 위한 중요한 통찰을 제공합니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글