본문으로 건너뛰기

[sglang] LTX-2 모델 성능 최적화: NPU 및 GPU에서의 지연 시간 단축 분석

PR 링크: sgl-project/sglang#22445 상태: Merged | 변경: +47 / -15

들어가며

최근 sglang 프로젝트의 LTX-2 모델에서 End-to-End (E2E) 지연 시간을 획기적으로 단축하는 Pull Request (PR)가 병합되었습니다. 이 PR은 특히 NPU 환경에서 약 27%의 성능 향상을, GPU 환경에서도 약 3%의 성능 향상을 가져왔습니다. 본 글에서는 이 PR이 어떤 코드 변경을 통해 이러한 성능 개선을 달성했는지, 그리고 그 원리가 무엇인지 심층적으로 분석하고 기술적인 인사이트를 공유하고자 합니다.

주요 개선 사항은 다음과 같습니다:

  1. torch.nn.RMSNorm을 커스텀 RMSNorm 구현으로 교체
  2. sgl_kernel_npufused_rmsnorm_without_weight 통합
  3. FlashAttention에서 마스크가 모두 1일 경우 None으로 설정하여 연산 최적화

이러한 변경들이 어떻게 실제 성능 향상으로 이어졌는지 코드 diff를 중심으로 자세히 살펴보겠습니다.

코드 분석

1. python/sglang/multimodal_gen/runtime/layers/layernorm.py 변경

이 파일에서는 RMSNorm 구현 방식에 대한 최적화가 이루어졌습니다. 특히 NPU 환경에서의 성능 향상을 위해 sgl_kernel_npu 라이브러리의 최적화된 함수를 활용하는 방안이 도입되었습니다.

Before:

기존에는 torch.nn.RMSNorm을 직접 사용하거나, CUDA 환경에서는 F.rms_norm을 사용하는 방식이었습니다. NPU 환경에 대한 별도의 최적화된 구현은 없었습니다.

# (기존 코드 - NPU 관련 최적화 없음)
# ...

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, variance_epsilon=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.variance_epsilon = variance_epsilon
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ... (RMSNorm 로직)

    def extra_repr(self) -> str:
        return f"hidden_size={self.hidden_size}, eps={self.variance_epsilon}"

# ...

After:

NPU 환경을 위한 fused_rmsnorm_without_weight 함수를 sgl_kernel_npu에서 임포트하고, 이를 활용하는 RMSNormNoWeight 커스텀 연산자(CustomOp)를 정의했습니다. 이 커스텀 연산자는 NPU 환경에서 fused_rmsnorm_without_weight를 호출하고, CUDA 및 CPU 환경에서는 기존의 F.rms_norm을 사용하도록 분기 처리합니다.

@@ -41,6 +41,9 @@
 
 if _is_npu:
     import torch_npu
+    from sgl_kernel_npu.norm.rmsnorm_without_weight import (
+        fused_rmsnorm_without_weight,
+    )
 
 if _is_musa:
     from sgl_kernel import fused_add_rmsnorm
@@ -292,6 +295,18 @@ def extra_repr(self) -> str:
         return f"hidden_size={self.hidden_size}, eps={self.variance_epsilon}"
 
+
+@CustomOp.register("rms_norm_no_weight")
+class RMSNormNoWeight(CustomOp):
+    def forward_native(self, x: torch.Tensor, eps: float) -> torch.Tensor:
+        return F.rms_norm(x, normalized_shape=(x.shape[-1],), eps=eps)
+
+    def forward_cuda(self, x: torch.Tensor, eps: float) -> torch.Tensor:
+        return self.forward_native(x, eps=eps)
+
+    def forward_npu(self, x: torch.Tensor, eps: float) -> torch.Tensor:
+        return fused_rmsnorm_without_weight(x, eps)
+
 
 # Copied and adapted from sglang
 @CustomOp.register("layer_norm")

2. python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py 변경

LTX-2 모델의 구체적인 구현 파일에서도 RMSNorm 사용 방식이 변경되었습니다. 또한, NPU 환경에서 FlashAttention의 성능 저하를 방지하기 위한 로직이 추가되었습니다.

Before:

모델 내부에서는 torch.nn.RMSNorm 대신 F.rms_norm을 직접 호출하는 헬퍼 함수 rms_norm을 사용하고 있었습니다. 이 함수는 torch.nn.RMSNorm과 유사하게 동작합니다.

@@ -35,9 +36,14 @@
     LayerwiseOffloadableModuleMixin,
 )
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
-from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
+from sglang.multimodal_gen.runtime.platforms import (
+    AttentionBackendEnum,
+    current_platform,
+)
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
+_is_npu = current_platform.is_npu()
+
 logger = init_logger(__name__)
 
 ADALN_NUM_BASE_PARAMS = 6
@@ -393,10 +399,6 @@ def forward(
         return cos_freqs.to(dtype=out_dtype), sin_freqs.to(dtype=out_dtype)
 
 
-def rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
-    return F.rms_norm(x, normalized_shape=(x.shape[-1],), eps=eps)
-
-
 class LTX2TextProjection(nn.Module):
     def __init__(
         self,
@@ -839,6 +841,7 @@ def __init__(
         super().__init__()
         self.idx = idx
         self.norm_eps = norm_eps
+        self.rms_norm = RMSNormNoWeight()
         # LTX2.3
         self.cross_attention_adaln = cross_attention_adaln
         self.use_local_av_cross_attention = use_local_av_cross_attention
@@ -1015,7 +1018,7 @@ def forward(
         self.scale_shift_table, batch_size, temb, slice(0, 3)
         )
         norm_hidden_states = (
-            rms_norm(hidden_states, self.norm_eps) * (1 + vscale_msa) + vshift_msa
+            self.rms_norm(hidden_states, self.norm_eps) * (1 + vscale_msa) + vshift_msa
         )
         attn_hidden_states = self.attn1(
             norm_hidden_states,
@@ -1031,7 +1034,8 @@ def forward(
             self.audio_scale_shift_table, batch_size, temb_audio, slice(0, 3)
         )
         norm_audio_hidden_states = (
-            rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_msa) + ashift_msa
+            self.rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_msa)
+            + ashift_msa
         )
         attn_audio_hidden_states = self.audio_attn1(
             norm_audio_hidden_states,
@@ -1056,7 +1060,7 @@ def forward(
                 self.prompt_scale_shift_table, batch_size, temb_prompt, slice(None)
             )
             norm_hidden_states = (
-                rms_norm(hidden_states, self.norm_eps) * (1 + vscale_q) + vshift_q
+                self.rms_norm(hidden_states, self.norm_eps) * (1 + vscale_q) + vshift_q
             )
             mod_encoder_hidden_states = (
                 encoder_hidden_states * (1 + v_prompt_scale) + v_prompt_shift
@@ -1078,7 +1082,8 @@ def forward(
                 slice(None),
             )
             norm_audio_hidden_states = (
-                rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_q) + ashift_q
+                self.rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_q)
+                + ashift_q
             )
             mod_audio_encoder_hidden_states = (
                 audio_encoder_hidden_states * (1 + a_prompt_scale) + a_prompt_shift
@@ -1092,7 +1097,7 @@ def forward(
                 audio_hidden_states + attn_audio_hidden_states * agate_q
             )
         else:
-            norm_hidden_states = rms_norm(hidden_states, self.norm_eps)
+            norm_hidden_states = self.rms_norm(hidden_states, self.norm_eps)
             attn_hidden_states = self.attn2(
                 norm_hidden_states,
                 context=encoder_hidden_states,
@@ -1100,7 +1105,7 @@ def forward(
             )
             hidden_states = hidden_states + attn_hidden_states
 
-            norm_audio_hidden_states = rms_norm(audio_hidden_states, self.norm_eps)
+            norm_audio_hidden_states = self.rms_norm(audio_hidden_states, self.norm_eps)
             attn_audio_hidden_states = self.audio_attn2(
                 norm_audio_hidden_states,
                 context=audio_encoder_hidden_states,
@@ -1108,8 +1113,8 @@ def forward(
             )
             audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
         # 3. Audio-to-Video and Video-to-Audio Cross-Attention
-        norm_hidden_states = rms_norm(hidden_states, self.norm_eps)
-        norm_audio_hidden_states = rms_norm(audio_hidden_states, self.norm_eps)
+        norm_hidden_states = self.rms_norm(hidden_states, self.norm_eps)
+        norm_audio_hidden_states = self.rms_norm(audio_hidden_states, self.norm_eps)
 
         # Compute combined ada params
         video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[
@@ -1221,7 +1226,7 @@ def forward(
             self.scale_shift_table, batch_size, temb, slice(3, 6)
         )
         norm_hidden_states = (
-            rms_norm(hidden_states, self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
+            self.rms_norm(hidden_states, self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
         )
         ff_output = self.ff(norm_hidden_states)
         hidden_states = hidden_states + ff_output * vgate_mlp
@@ -1230,7 +1235,8 @@ def forward(
             self.audio_scale_shift_table, batch_size, temb_audio, slice(3, 6)
         )
         norm_audio_hidden_states = (
-            rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
+            self.rms_norm(audio_hidden_states, self.norm_eps) * (1 + ascale_mlp)
+            + ashift_mlp
         )
         audio_ff_output = self.audio_ff(norm_audio_hidden_states)
         audio_hidden_states = audio_hidden_states + audio_ff_output * agate_mlp
@@ -1813,6 +1819,17 @@ def forward(
             audio_encoder_hidden_states = self.audio_caption_projection(
                 audio_encoder_hidden_states
             )
+
+        if _is_npu:
+            # If the 'encoder_attention_mask' is provided and it is all ones,
+            # it can be set to 'None' to avoid the degradation of performance on the NPU side,
+            # where the mask, even though it has no affect,
+            # can lead to the introduction of multiple small operators.
+            if encoder_attention_mask is not None and torch.all(
+                encoder_attention_mask == 1
+            ):
+                encoder_attention_mask = None
+
         # 5. Run blocks
         skip_video_self_attn_blocks = set(skip_video_self_attn_blocks or ()) 
         skip_audio_self_attn_blocks = set(skip_audio_self_attn_blocks or ()) 

After:

  1. RMSNorm 교체: 기존의 rms_norm 헬퍼 함수가 제거되고, 새로 정의된 RMSNormNoWeight 커스텀 연산자(self.rms_norm)가 인스턴스화되어 사용됩니다. 이를 통해 NPU 환경에서 최적화된 fused_rmsnorm_without_weight 함수를 호출할 수 있게 됩니다.
    @@ -839,6 +841,7 @@ def __init__(
         super().__init__()
         self.idx = idx
         self.norm_eps = norm_eps
    
  •    self.rms_norm = RMSNormNoWeight()
       # LTX2.3
       self.cross_attention_adaln = cross_attention_adaln
       self.use_local_av_cross_attention = use_local_av_cross_attention
    
    모든 `rms_norm(...)` 호출이 `self.rms_norm(...)`으로 변경되었습니다.
    
    
  1. FlashAttention 마스크 최적화: NPU 환경(_is_npu가 True일 때)에서 encoder_attention_maskNone이 아니고 모든 값이 1인 경우, 이를 None으로 설정합니다. 이는 NPU 하드웨어에서 모든 값이 1인 마스크가 오히려 불필요한 연산 오버헤드를 유발하고 작은 연산자들로 분해되는 성능 저하를 일으킬 수 있기 때문입니다. 마스크가 모든 토큰을 포함하도록 하는 경우, 이를 명시적으로 None으로 설정함으로써 FlashAttention이 더 효율적인 연산 경로를 타도록 유도합니다.
    @@ -1813,6 +1819,17 @@ def forward(
             audio_encoder_hidden_states = self.audio_caption_projection(
                 audio_encoder_hidden_states
             )
    
  •    if _is_npu:
    
  •        # If the 'encoder_attention_mask' is provided and it is all ones,
    
  •        # it can be set to 'None' to avoid the degradation of performance on the NPU side,
    
  •        # where the mask, even though it has no affect,
    
  •        # can lead to the introduction of multiple small operators.
    
  •        if encoder_attention_mask is not None and torch.all(
    
  •            encoder_attention_mask == 1
    
  •        ):
    
  •            encoder_attention_mask = None
    
  •    # 5. Run blocks
       skip_video_self_attn_blocks = set(skip_video_self_attn_blocks or ()) 
       skip_audio_self_attn_blocks = set(skip_audio_self_attn_blocks or ()) 
    

리뷰 피드백 반영

리뷰 과정에서 _is_npu와 같은 플랫폼 종속적인 코드를 제거하고, 이를 CustomOp으로 추상화하자는 제안이 있었습니다. ping1jing2님과 Makcum888e님의 피드백에 따라, RMSNormNoWeight라는 새로운 CustomOp을 생성하여 NPU 관련 로직을 캡슐화했습니다. 이로써 코드의 가독성과 유지보수성이 향상되었습니다.

왜 이게 좋은가?

1. NPU 성능 극대화

  • 커스텀 RMSNorm 및 퓨전 연산: torch.nn.RMSNorm 대신 sgl_kernel_npu에서 제공하는 fused_rmsnorm_without_weight를 사용함으로써, NPU 하드웨어에 최적화된 연산이 가능해졌습니다. 이는 연산의 융합(fusion)을 통해 메모리 접근을 줄이고 연산 효율성을 높여 NPU에서의 지연 시간을 크게 단축시킵니다. PR 설명에 따르면 NPU에서 E2E 지연 시간이 102.74s에서 75.69s로 약 27% 감소했습니다.
  • FlashAttention 마스크 최적화: NPU에서 모든 값이 1인 마스크가 오히려 성능 저하를 유발하는 특성을 파악하고, 이를 None으로 처리하여 FlashAttention 연산이 최적의 경로로 실행되도록 했습니다. 이는 특히 NPU 아키텍처의 세부 사항을 이해하고 활용한 결과입니다.

2. GPU 성능 향상

  • FlashAttention 마스크 최적화: 비록 NPU만큼 극적인 효과는 아니지만, GPU에서도 유사한 최적화가 적용되어 E2E 지연 시간이 136.62s에서 131.70s로 약 3% 감소했습니다. 이는 하드웨어별 특성을 고려한 최적화가 범용적인 성능 개선으로 이어질 수 있음을 보여줍니다.

3. 코드 품질 및 유지보수성 향상

  • CustomOp 도입: 리뷰어의 제안에 따라 RMSNormNoWeightCustomOp으로 구현함으로써, 플랫폼별 구현을 깔끔하게 분리했습니다. 이는 코드의 의존성을 낮추고, 향후 다른 하드웨어 지원이나 RMSNorm 구현 변경 시 유지보수를 용이하게 합니다.
  • 가독성 증가: _is_npu와 같은 조건부 로직이 CustomOp 내부로 캡슐화되면서, 모델의 핵심 로직(ltx_2.py)은 더 간결해지고 가독성이 높아졌습니다.

일반적인 교훈

  • 하드웨어별 최적화의 중요성: 딥러닝 모델의 성능은 하드웨어 아키텍처에 크게 좌우됩니다. 특히 NPU와 같은 특수 목적 하드웨어에서는 해당 하드웨어에 최적화된 라이브러리나 연산 방식을 활용하는 것이 필수적입니다.
  • 연산자 퓨전(Operator Fusion)의 효과: 여러 개의 작은 연산을 하나의 큰 연산으로 묶는 퓨전 기법은 메모리 대역폭 사용량을 줄이고 캐시 효율성을 높여 성능을 크게 향상시킬 수 있습니다. RMSNorm 퓨전이 대표적인 예입니다.
  • 라이브러리/프레임워크의 특성 이해: FlashAttention과 같은 고급 연산은 특정 조건(예: 마스크 형태)에서 성능 특성이 달라질 수 있습니다. 이러한 라이브러리의 세부 동작 방식을 이해하고 활용하는 것이 중요합니다.
  • 리뷰를 통한 코드 품질 향상: 코드 리뷰는 단순히 버그를 찾는 것을 넘어, 아키텍처 개선, 플랫폼 종속성 제거 등 코드 품질과 유지보수성을 높이는 데 중요한 역할을 합니다.

결론

이번 PR은 LTX-2 모델의 성능을 NPU와 GPU 환경 모두에서 개선하는 중요한 작업을 수행했습니다. 커스텀 RMSNorm 구현과 NPU 특화 최적화, 그리고 FlashAttention 마스크 처리 로직 개선을 통해 E2E 지연 시간을 효과적으로 단축했습니다. 또한, 코드 리뷰 과정을 통해 플랫폼 종속성을 제거하고 CustomOp을 도입하여 코드의 품질과 확장성을 높였습니다. 이러한 최적화 사례는 딥러닝 모델 개발 시 하드웨어 특성을 고려한 연산자 수준의 최적화가 얼마나 중요한지를 다시 한번 보여줍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글