[sglang] [SGLang] VAE 병렬 디코딩 최적화: CFG 병렬화와의 시너지 분석
PR 링크: sgl-project/sglang#27875 상태: Merged | 변경: +168 / -25
들어가며: VAE 디코딩의 병목과 CFG 병렬화의 한계
Diffusion 모델, 특히 최근의 Wan2.1이나 Qwen2-VL과 같은 고해상도 비디오/이미지 생성 모델에서 VAE(Variational Autoencoder) 디코딩은 전체 파이프라인에서 상당한 연산 비중을 차지합니다. Latent 공간에서 픽셀 공간으로 복원하는 과정은 고해상도일수록 메모리 점유율이 높고 시간이 오래 걸립니다.
기존 SGLang의 구조에서는 CFG(Classifier-Free Guidance) 병렬화를 사용할 때, VAE 디코딩 단계가 종종 MAIN_RANK_ONLY로 설정되어 특정 Rank에서만 실행되거나, 병렬화 로직이 SP(Sequence Parallelism) 그룹에만 강하게 결합되어 있었습니다. 이는 멀티 GPU 환경에서도 VAE 단계에서 자원을 충분히 활용하지 못하는 병목을 야기했습니다.
이번 PR([diffusion] optimize: enable vae parallel decode with cfg-parallel)은 이러한 구조적 제약을 해결하여, CFG 병렬화 환경에서도 모든 Rank가 VAE 디코딩에 참여할 수 있도록 최적화한 사례입니다.
핵심 변경 사항 분석
1. 추상화된 병렬 코디네이터 도입 (parallel_state.py)
가장 중요한 변화는 VAE 병렬화를 위한 그룹 결정 로직을 추상화한 것입니다. 기존에는 VAE 병렬화가 SP 그룹과 동일시되었으나, 이제는 CFG 그룹도 활용할 수 있도록 get_decode_parallel_group_coordinator가 도입되었습니다.
Before:
기존에는 VAE 관련 함수들이 직접 get_sp_group()을 호출하여 SP 환경에서만 동작하도록 설계되었습니다.
After:
def get_decode_parallel_group_coordinator() -> GroupCoordinator:
sp_group = get_sp_group()
cfg_group = get_cfg_group()
# SP가 활성화되지 않았더라도 CFG 병렬화가 활성화되어 있다면 CFG 그룹을 사용
if sp_group.world_size == 1 and cfg_group.world_size > 1:
return cfg_group
return sp_group
이 변경을 통해 SP(시퀀스 병렬화)를 쓰지 않는 상황에서도 CFG 병렬화가 켜져 있다면, 두 Rank가 협력하여 VAE 디코딩을 수행할 수 있는 기반이 마련되었습니다.
2. VAE 유틸리티의 범용화 (wan_dist_utils.py)
Wan VAE와 같이 분산 처리가 필요한 모델의 내부 유틸리티들이 특정 병렬화 그룹(SP)에 의존하던 것을 새로운 코디네이터를 사용하도록 수정했습니다.
Before:
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_sp_group,
get_sp_parallel_rank,
get_sp_world_size,
)
# ...
rank = get_sp_parallel_rank()
world_size = get_sp_world_size()
After:
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_decode_parallel_group_coordinator,
get_decode_parallel_rank,
get_decode_parallel_world_size,
)
# ...
decode_group = get_decode_parallel_group_coordinator()
rank = get_decode_parallel_rank()
world_size = get_decode_parallel_world_size()
이로써 halo_exchange(경계 픽셀 교환)나 all_gather 연산이 SP 그룹뿐만 아니라 CFG 그룹 내에서도 올바르게 동작하게 되었습니다. 이는 병렬 디코딩 시 발생하는 아티팩트를 방지하는 핵심 로직입니다.
3. 디코딩 스테이지의 병렬 실행 전략 수정 (decoding.py)
실제 런타임에서 디코딩 스테이지가 어떤 방식으로 실행될지 결정하는 parallelism_type 로직이 개선되었습니다.
Before:
@property
def parallelism_type(self) -> StageParallelismType:
if get_global_server_args().enable_cfg_parallel:
return StageParallelismType.MAIN_RANK_ONLY
return StageParallelismType.REPLICATED
기존에는 CFG 병렬화가 켜지면 무조건 메인 Rank에서만 디코딩을 수행했습니다.
After:
@property
def parallelism_type(self) -> StageParallelismType:
server_args = get_global_server_args()
if server_args.enable_cfg_parallel:
if self._can_use_parallel_decode():
return StageParallelismType.REPLICATED
return StageParallelismType.MAIN_RANK_ONLY
return StageParallelismType.REPLICATED
def _can_use_parallel_decode(self) -> bool:
return (
model_parallel_is_initialized()
and get_decode_parallel_world_size() > 1
and self.vae.use_parallel_decode
)
이제 병렬 디코딩이 가능한 조건(_can_use_parallel_decode)을 만족하면 REPLICATED 모드로 동작합니다. 여기서 REPLICATED는 각 Rank가 동일한 작업을 하는 것이 아니라, 분산된 데이터를 각자 처리하는 병렬 모드로 동작함을 의미합니다.
왜 이게 좋은 최적화인가?
1. 압도적인 성능 향상
H200 2GPU 환경에서 Wan2.1-T2V 모델(832x480x81f)로 테스트한 결과는 놀랍습니다.
- Total Time: 6313.4 ms → 5384.2 ms (약 15% 감소)
- Decode Time: 2631.4 ms → 1814.8 ms (약 31% 감소)
- Peak Reserved Memory: 37.5 GB → 31.9 GB (약 15% 절감)
디코딩 시간만 놓고 보면 30% 이상의 성능 향상이 있었으며, 메모리 사용량까지 줄어들었습니다. 이는 한 GPU가 모든 픽셀을 디코딩하던 부담을 두 GPU가 나누어 가졌기 때문입니다.
2. 유연한 병렬화 전략
이 최적화의 핵심은 "병렬화 그룹의 재정의"에 있습니다. 기존에는 'CFG 병렬화'와 'VAE 병렬화'를 별개의 차원으로 보았으나, 이 PR은 "현재 가용한 GPU 자원(Rank)이 있다면, 그것이 어떤 목적으로 묶였든 디코딩에 동원한다"는 실용적인 접근을 취했습니다.
3. 오버헤드 고려 (Speed Mode)
PR 설명에 따르면, 9프레임과 같은 작은 출력물에서는 병렬 디코딩의 통신 오버헤드가 이득보다 클 수 있습니다. 개발자는 이를 인지하고 performance_mode=speed일 때만 이 최적화가 적극적으로 활성화되도록 제한하여, 일반적인 케이스에서의 회귀(Regression)를 방지했습니다.
마치며
이번 최적화는 분산 환경에서 자원을 어떻게 하면 '놀리지 않고' 끝까지 활용할 수 있는지를 잘 보여줍니다. 특히 VAE처럼 연산 집약적인 컴포넌트를 가진 모델에서는 이러한 그룹 코디네이션의 유연성이 전체 처리량(Throughput)과 지연 시간(Latency) 개선에 결정적인 역할을 합니다. SGLang의 이번 업데이트는 고해상도 비디오 생성 서비스를 운영하는 엔지니어들에게 매우 반가운 소식이 될 것입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang Diffusion 모델의 FP8 GEMM 최적화: 41.5% 성능 향상 달성
- [sglang] SGLang LTX-2 VAE 디코딩 성능 최적화: channels_last_3d 도입으로 4.5배 속도 향상
- [sglang] SGLang의 Spectral Progressive Diffusion 도입: 추론 속도 최대 2.78배 향상
- [sglang] SGLang의 Ideogram4 추론 성능 최적화: Denoising 루프 내 오버헤드 제거
- [sglang] [SGLang] LingBot 실시간 서빙 최적화: 카메라 컨디셔닝 캐싱과 전송 프로토콜 개선
PR Analysis 의 다른글
- 이전글 [sglang] SGLang PD-Disaggregation 최적화: Mori 백엔드에서의 증분 KV 전송 구현
- 현재글 : [sglang] [SGLang] VAE 병렬 디코딩 최적화: CFG 병렬화와의 시너지 분석
- 다음글 [sglang] SGLang, GPU 간 VAE 디코딩 최적화를 통한 이미지 생성 속도 향상
댓글