[sglang] SGLang에서 torch.compile을 활용한 Wan 모델 추론 가속화
PR 링크: sgl-project/sglang#25256 상태: Merged | 변경: +0 / -0
들어가며
최근 대규모 언어 모델 및 멀티모달 모델의 추론 성능 최적화는 엔지니어링의 핵심 과제입니다. 특히 SGLang 프로젝트에서 지원하는 Wan 모델의 경우, 복잡한 연산 그래프로 인해 추론 지연 시간이 발생할 수 있습니다. 본 PR은 torch.compile을 전략적으로 도입하여 MUSA(MT S5000) 및 CUDA(H200) 환경에서 MulAdd 및 LayerNorm 연산의 효율성을 극대화하고, 추론 속도를 개선하는 것을 목표로 합니다.
코드 분석
1. MUSA 디바이스를 위한 torch.compile 적용 (elementwise.py)
elementwise.py 파일에서는 forward_musa 메서드에 @torch.compile 데코레이터를 추가하여 MUSA 환경에서의 연산 그래프를 최적화했습니다.
@torch.compile
def forward_musa(
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, k: int = 0
):
return self.forward_native(a, b, c, k=k)
2. LayerNorm 연산 최적화 (layernorm.py)
layernorm.py에서는 NPU 환경을 제외한 나머지 환경에서 forward_native 메서드에 torch.compile을 적용했습니다. 이는 PyTorch의 JIT 컴파일 기능을 활용하여 커널 융합(Kernel Fusion)을 유도합니다.
@torch.compile(disable=current_platform.is_npu())
def forward_native(
self, residual: torch.Tensor, ...
):
# ...
3. Wan 모델 데이터 레이아웃 정렬 (wanvideo.py)
wanvideo.py에서는 텐서 연산 후 메모리 레이아웃을 contiguous()하게 변경하여 후속 연산에서의 메모리 접근 효율을 높였습니다.
# Before
hidden_states = hidden_states.flatten(2).transpose(1, 2)
# After
hidden_states = hidden_states.flatten(2).transpose(1, 2).contiguous()
왜 이게 좋은가
성능 향상 수치
- MUSA (MT S5000): Wan2.2-T2V-A14B 모델 기준, 단일 스텝 추론 속도가 약 1.09배 (13.14ms -> 11.97ms) 향상되었습니다.
- CUDA (H200): Wan2.1-T2V-1.3B 모델 기준, 약 1.05배 (0.3381ms -> 0.3206ms) 향상되었습니다.
기술적 교훈
- Kernel Fusion의 힘:
torch.compile은 파이썬 수준의 연산을 그래프로 묶어 CUDA/MUSA 커널로 변환합니다. 특히 LayerNorm과 같은 빈번한 요소별 연산에서 오버헤드를 획기적으로 줄여줍니다. - 메모리 연속성(Contiguity):
transpose이후contiguous()를 호출하는 것은 메모리 단편화를 방지하고, 후속 레이어에서 커널이 텐서를 효율적으로 읽을 수 있게 합니다. 리뷰어들의 논의처럼 특정 상황에서는 자동 처리되기도 하지만, 명시적인 호출은 성능 안정성을 보장합니다. - 플랫폼별 분기 처리:
is_musa()나is_npu()와 같은 환경 체크를 통해 특정 하드웨어에서만 최적화가 동작하도록 제한함으로써, 범용성과 성능 사이의 균형을 맞추는 것이 중요합니다.
리뷰 피드백 반영
리뷰 과정에서 wanvideo.py의 contiguous() 호출에 대해 "자동으로 처리될 수 있다"는 의견이 있었으나, 코드의 명시성을 높이고 잠재적인 성능 저하를 방지하기 위해 유지되었습니다. 또한, MUSA 전용 최적화가 다른 플랫폼에 영향을 주지 않도록 is_musa 체크를 추가하는 방향으로 논의가 마무리되었습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang의 MLA KV 캐시 쓰기 최적화: TMA Bulk-Store 도입
- [sglang] SGLang의 MHC 파이프라인 최적화: 커널 퓨전과 DeepGemm 도입
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
PR Analysis 의 다른글
- 이전글 [sglang] SGLang 멀티모달 파이프라인의 VAE 정밀도 최적화: bf16 도입을 통한 메모리 효율 개선
- 현재글 : [sglang] SGLang에서 torch.compile을 활용한 Wan 모델 추론 가속화
- 다음글 [sglang] SGLang, 레이어별 오프로딩 기본값 설정을 통한 인코더/VAE 성능 최적화
댓글