본문으로 건너뛰기

[sglang] LTX2 스플릿 로터리 커널 최적화: 헤드 배치 처리로 성능 2배 향상

PR 링크: sgl-project/sglang#24732 상태: Merged | 변경: +0 / -0

들어가며

최근 LLM 모델들은 점점 더 복잡하고 거대해지고 있으며, 이에 따라 추론 성능 최적화는 매우 중요한 과제가 되었습니다. 특히, LTX2와 같은 모델에서 사용되는 Rotary Positional Embedding (RoPE)은 어텐션 메커니즘의 핵심 요소로, 이 부분의 성능 개선은 전체 추론 속도에 큰 영향을 미칩니다.

이번 PR (#24732)은 SGLang 레포지토리에서 LTX2 모델의 스플릿 로터리 커널(split rotary kernel)을 최적화하여, 기존 대비 최대 1.94배의 성능 향상을 달성했습니다. 이 글에서는 해당 PR의 코드 변경 사항을 자세히 분석하고, 왜 이러한 변경이 성능 향상으로 이어졌는지, 그리고 이 최적화가 가지는 일반적인 교훈은 무엇인지 살펴보겠습니다.

코드 변경 분석

이번 PR의 핵심은 Triton 커널에서 개별 토큰/헤드 쌍마다 프로그램을 실행하던 방식을 개선하여, 하나의 프로그램에서 여러 개의 어텐션 헤드를 배치(batch) 처리하도록 변경한 것입니다.

python/sglang/jit_kernel/diffusion/triton/ltx2_rotary.py

가장 중요한 변경은 _ltx2_split_rotary_kernel 함수의 로직과 호출 방식입니다.

1. 커널 내부 로직 변경

Before:

def _ltx2_split_rotary_kernel(
    # ... (기존 인자들)
):
    pid_bt = tl.program_id(0)
    head = tl.program_id(1) # <-- 각 프로그램이 하나의 헤드를 처리
    batch = pid_bt // seq_len
    token = pid_bt - batch * seq_len
    offsets = tl.arange(0, BLOCK_HALF)
    mask = offsets < half_dim

    x_base = ((batch * seq_len + token) * num_heads + head) * head_dim
    # ... (기존 로딩 및 연산)

    tl.store(out_ptr + x_base + offsets, out_first, mask=mask)
    tl.store(out_ptr + x_base + half_dim + offsets, out_second, mask=mask)

After:

def _ltx2_split_rotary_kernel(
    # ... (기존 인자들)
    BLOCK_HEADS: tl.constexpr, # <-- 새로운 인자
    BLOCK_HALF: tl.constexpr,
):
    pid_bt = tl.program_id(0)
    head_block = tl.program_id(1) # <-- 이제 각 프로그램이 여러 헤드 블록을 처리
    batch = pid_bt // seq_len
    token = pid_bt - batch * seq_len
    heads = head_block * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS) # <-- 처리할 헤드들을 계산
    offsets = tl.arange(0, BLOCK_HALF)
    mask = (heads[:, None] < num_heads) & (offsets[None, :] < half_dim) # <-- 헤드와 오프셋 모두에 대한 마스크

    x_base = ((batch * seq_len + token) * num_heads + heads[:, None]) * head_dim # <-- 헤드 차원 확장
    cos_base = (
        batch * stride_cos_b + heads[:, None] * stride_cos_h + token * stride_cos_t
    )
    sin_base = (
        batch * stride_sin_b + heads[:, None] * stride_sin_h + token * stride_sin_t
    )

    x_first = tl.load(x_ptr + x_base + offsets[None, :], mask=mask, other=0.0) # <-- 로딩 시 차원 확장
    x_second = tl.load(
        x_ptr + x_base + half_dim + offsets[None, :], mask=mask, other=0.0
    )
    cos = tl.load(cos_ptr + cos_base + offsets[None, :], mask=mask, other=0.0)
    sin = tl.load(sin_ptr + sin_base + offsets[None, :], mask=mask, other=0.0)

    # ... (기존 연산 로직은 동일하게 유지, 다만 브로드캐스팅 활용)

    tl.store(out_ptr + x_base + offsets[None, :], out_first, mask=mask) # <-- 저장 시 차원 확장
    tl.store(out_ptr + x_base + half_dim + offsets[None, :], out_second, mask=mask)

주요 변경점은 다음과 같습니다:

  • BLOCK_HEADS 상수 도입: 하나의 Triton 프로그램이 처리할 헤드의 수를 정의합니다. PR에서는 최대 16개의 헤드를 배치 처리하도록 설정했습니다 (block_heads = min(16, triton.next_power_of_2(num_heads))).
  • program_id(1)의 의미 변경: 이전에는 헤드 인덱스였지만, 이제는 헤드 블록 인덱스(head_block)가 됩니다.
  • heads 계산: head_blockBLOCK_HEADS를 사용하여 현재 프로그램이 처리해야 할 실제 헤드들의 범위를 계산합니다 (heads = head_block * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS)).
  • 데이터 로딩 및 저장 시 차원 확장: x_base, cos_base, sin_base 계산 시 heads[:, None]와 같이 헤드 차원을 확장하여, 한 번의 로드/스토어 연산으로 여러 헤드의 데이터를 처리할 수 있도록 합니다. offsetsoffsets[None, :]로 확장하여 헤드와 오프셋 차원 모두에 대한 마스크를 적용합니다.
  • 마스크 로직 업데이트: 헤드 범위를 벗어나는 경우를 방지하기 위해 mask 계산 시 heads[:, None] < num_heads 조건을 추가했습니다.

2. 커널 호출 방식 변경

Before:

_ltx2_split_rotary_kernel[
    (batch * seq_len, num_heads) # <-- Grid: (batch*seq_len, num_heads)
](
    # ... (인자들)
    num_warps=1,
)

After:

block_heads = min(16, triton.next_power_of_2(num_heads))
num_warps = min(8, max(1, block_heads))
grid = (batch * seq_len, triton.cdiv(num_heads, block_heads)) # <-- Grid: (batch*seq_len, num_heads / block_heads)
_ltx2_split_rotary_kernel[grid](
    # ... (인자들)
    BLOCK_HEADS=block_heads, # <-- BLOCK_HEADS 전달
    num_warps=num_warps,
)
  • Grid 크기 변경: grid의 두 번째 차원이 num_heads에서 triton.cdiv(num_heads, block_heads)로 변경되었습니다. 이는 각 프로그램이 여러 헤드를 처리하므로, 필요한 프로그램 수를 줄이는 효과를 가져옵니다.
  • BLOCK_HEADS 전달: 커널 내부에서 사용할 BLOCK_HEADS 값을 커널 호출 시 전달합니다.
  • num_warps 조정: BLOCK_HEADS 값에 따라 num_warps를 동적으로 조정하여, GPU 하드웨어 활용률을 높이도록 했습니다.

왜 이게 좋은가?

성능 향상

이 PR은 Triton 커널의 실행 방식을 근본적으로 변경하여 상당한 성능 향상을 이루었습니다. H200 벤치마크 결과는 이를 명확히 보여줍니다:

Case Main This PR Speedup Delta
B1_S1024_H32_D128 37.952 us 25.232 us 1.50x -33.5%
B1_S4096_H16_D128 57.792 us 29.792 us 1.94x -48.5%
  • B1_S1024_H32_D128: 배치 크기 1, 시퀀스 길이 1024, 헤드 수 32, 헤드 차원 128 케이스에서 약 1.50배의 속도 향상을 보였습니다.
  • B1_S4096_H16_D128: 배치 크기 1, 시퀀스 길이 4096, 헤드 수 16, 헤드 차원 128 케이스에서는 무려 1.94배의 속도 향상을 기록했습니다.

이러한 성능 향상의 주된 이유는 다음과 같습니다:

  1. 커널 런치 오버헤드 감소: 이전에는 각 토큰-헤드 쌍마다 별도의 Triton 커널 프로그램이 실행되었습니다. 이는 특히 헤드 수가 많을 때 수많은 작은 커널 런치로 이어져 상당한 오버헤드를 발생시켰습니다. 헤드들을 배치 처리함으로써 커널 런치 횟수를 크게 줄여 이 오버헤드를 절감했습니다.
  2. GPU 활용률 증대: 하나의 프로그램에서 더 많은 작업을 처리하게 되면서, GPU의 병렬 처리 능력을 더 효과적으로 활용할 수 있게 되었습니다. 특히, num_warpsBLOCK_HEADS에 맞춰 조정함으로써 워프(warp) 수준의 병렬성을 최적화했습니다.
  3. 메모리 접근 패턴 개선: 여러 헤드의 데이터를 한 번에 로드하고 처리함으로써, 캐시 효율성을 높이고 메모리 접근 패턴을 개선했을 가능성이 있습니다. (정확한 캐시 효과는 더 깊은 분석이 필요하지만, 일반적으로 배치 처리는 이러한 이점을 가져옵니다.)

일반적인 교훈

이 PR은 다음과 같은 중요한 최적화 교훈을 제공합니다:

  • 작은 커널 런치 오버헤드의 중요성: GPU 프로그래밍에서는 개별 커널 런치 비용이 생각보다 클 수 있습니다. 가능한 한 많은 작업을 하나의 커널에 통합하여 런치 횟수를 줄이는 것이 성능 향상에 효과적입니다.
  • 작업 단위(Work Unit) 재정의: 기존의 작업 단위를 재검토하고, 더 큰 단위로 묶어 처리함으로써 병렬성 및 효율성을 높일 수 있습니다. 여기서는 '토큰-헤드 쌍'에서 '헤드 블록'으로 작업 단위를 확장했습니다.
  • 하드웨어 특성 고려: BLOCK_HEADSnum_warps의 동적 조정은 GPU 아키텍처의 특성(예: 워프 크기)을 고려하여 최적의 성능을 이끌어내는 좋은 예시입니다.
  • BF16 연산 순서 유지: PyTorch와의 호환성을 위해 BF16 연산 순서를 그대로 유지하면서 성능을 개선한 점은, 기존 코드와의 호환성을 유지하는 것이 얼마나 중요한지를 보여줍니다.

리뷰 댓글 분석

제공된 리뷰 댓글은 주로 CI 관련 재실행 요청이었습니다. 이는 코드 자체의 기술적 문제보다는, PR이 메인 브랜치와 통합되기 전에 CI 파이프라인을 통해 충분히 검증받아야 함을 시사합니다. 즉, 코드 변경 자체는 안정적이며, CI 통과 여부가 최종 통합의 관건이었음을 알 수 있습니다.

결론

이번 PR은 LTX2 모델의 스플릿 로터리 커널을 최적화하여, 헤드 배치 처리를 통해 커널 런치 오버헤드를 줄이고 GPU 활용률을 높임으로써 최대 1.94배의 성능 향상을 달성했습니다. 이는 LLM 추론 성능 최적화에 있어 Triton 커널 레벨에서의 세심한 최적화가 얼마나 중요한지를 보여주는 좋은 사례입니다. 이러한 최적화 기법은 다른 모델이나 연산에도 적용될 수 있으며, LLM 개발자들에게 유용한 통찰을 제공합니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글