본문으로 건너뛰기

[sglang] SGLang에서 torch.compile을 활용한 Wan 모델 추론 가속화

PR 링크: sgl-project/sglang#25256 상태: Merged | 변경: +0 / -0

들어가며

최근 대규모 언어 모델 및 멀티모달 모델의 추론 성능 최적화는 엔지니어링의 핵심 과제입니다. 특히 SGLang 프로젝트에서 지원하는 Wan 모델의 경우, 복잡한 연산 그래프로 인해 추론 지연 시간이 발생할 수 있습니다. 본 PR은 torch.compile을 전략적으로 도입하여 MUSA(MT S5000) 및 CUDA(H200) 환경에서 MulAdd 및 LayerNorm 연산의 효율성을 극대화하고, 추론 속도를 개선하는 것을 목표로 합니다.

코드 분석

1. MUSA 디바이스를 위한 torch.compile 적용 (elementwise.py)

elementwise.py 파일에서는 forward_musa 메서드에 @torch.compile 데코레이터를 추가하여 MUSA 환경에서의 연산 그래프를 최적화했습니다.

@torch.compile
def forward_musa(
    self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, k: int = 0
):
    return self.forward_native(a, b, c, k=k)

2. LayerNorm 연산 최적화 (layernorm.py)

layernorm.py에서는 NPU 환경을 제외한 나머지 환경에서 forward_native 메서드에 torch.compile을 적용했습니다. 이는 PyTorch의 JIT 컴파일 기능을 활용하여 커널 융합(Kernel Fusion)을 유도합니다.

@torch.compile(disable=current_platform.is_npu())
def forward_native(
    self, residual: torch.Tensor, ...
):
    # ...

3. Wan 모델 데이터 레이아웃 정렬 (wanvideo.py)

wanvideo.py에서는 텐서 연산 후 메모리 레이아웃을 contiguous()하게 변경하여 후속 연산에서의 메모리 접근 효율을 높였습니다.

# Before
hidden_states = hidden_states.flatten(2).transpose(1, 2)

# After
hidden_states = hidden_states.flatten(2).transpose(1, 2).contiguous()

왜 이게 좋은가

성능 향상 수치

  • MUSA (MT S5000): Wan2.2-T2V-A14B 모델 기준, 단일 스텝 추론 속도가 약 1.09배 (13.14ms -> 11.97ms) 향상되었습니다.
  • CUDA (H200): Wan2.1-T2V-1.3B 모델 기준, 약 1.05배 (0.3381ms -> 0.3206ms) 향상되었습니다.

기술적 교훈

  1. Kernel Fusion의 힘: torch.compile은 파이썬 수준의 연산을 그래프로 묶어 CUDA/MUSA 커널로 변환합니다. 특히 LayerNorm과 같은 빈번한 요소별 연산에서 오버헤드를 획기적으로 줄여줍니다.
  2. 메모리 연속성(Contiguity): transpose 이후 contiguous()를 호출하는 것은 메모리 단편화를 방지하고, 후속 레이어에서 커널이 텐서를 효율적으로 읽을 수 있게 합니다. 리뷰어들의 논의처럼 특정 상황에서는 자동 처리되기도 하지만, 명시적인 호출은 성능 안정성을 보장합니다.
  3. 플랫폼별 분기 처리: is_musa()is_npu()와 같은 환경 체크를 통해 특정 하드웨어에서만 최적화가 동작하도록 제한함으로써, 범용성과 성능 사이의 균형을 맞추는 것이 중요합니다.

리뷰 피드백 반영

리뷰 과정에서 wanvideo.pycontiguous() 호출에 대해 "자동으로 처리될 수 있다"는 의견이 있었으나, 코드의 명시성을 높이고 잠재적인 성능 저하를 방지하기 위해 유지되었습니다. 또한, MUSA 전용 최적화가 다른 플랫폼에 영향을 주지 않도록 is_musa 체크를 추가하는 방향으로 논의가 마무리되었습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글