본문으로 건너뛰기

[SGLang] Diffusion Triton Rotary Embedding 다중 헤드 병렬 처리 최적화

PR 링크: sgl-project/sglang#21387 상태: Merged | 변경: +46 / -26

들어가며

Diffusion 모델에서 rotary position embedding(RoPE)은 모든 attention layer에서 호출되는 핵심 연산이다. 기존 Triton 커널은 (batch * num_tokens * num_heads) 크기의 1D grid로 launch되어, 각 program instance가 하나의 (토큰, 헤드) 쌍만 처리했다. 이는 헤드 수가 많을 때 grid 크기가 과도하게 커지고 스케줄링 오버헤드가 발생한다.

이 PR은 grid를 (batch * num_tokens, ceil(num_heads / BLOCK_HEADS)) 2D로 재구성하여, 각 instance가 BLOCK_HEADS개의 헤드를 동시에 처리한다.

핵심 코드 분석

Grid 차원 변경

Before:

x_reshaped = x.view(-1, head_size)  # [B*T*H, head_size]
grid = (bsz * num_tokens * num_heads,)  # 1D: 모든 (토큰,헤드) 조합

After:

x_reshaped = x.view(bsz * num_tokens, num_heads, head_size)  # [B*T, H, hs]
# 2D grid: (토큰, 헤드 블록)
_rotary_embedding_kernel[
    lambda META: (bsz * num_tokens, triton.cdiv(num_heads, META["BLOCK_HEADS"]))
](...)

커널 내부 다중 헤드 처리

Before:

row_idx = tl.program_id(0)
token_idx = (row_idx // num_heads) % num_tokens
x_row_ptr = x_ptr + row_idx * stride_x_row
# 1개 헤드의 x1, x2를 로드하여 RoPE 적용
x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask)

After:

bt_idx = tl.program_id(0)
head_block_idx = tl.program_id(1)
head_offsets = head_block_idx * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS)
head_mask = head_offsets < num_heads

# cos/sin은 헤드 간 공유 (1D 로드)
cos_vals = tl.load(cos_row_ptr + offsets_half, mask=half_mask)

# x는 [BLOCK_HEADS, BLOCK_HS_HALF] 2D 로드
x1_vals = tl.load(x_row_ptrs + offsets_x1[None, :], mask=mask)
# Broadcasting: cos/sin [1, HS] * x [BH, HS]
cos_fp32 = cos_vals.to(tl.float32)[None, :]
o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)

왜 이게 좋은가

  1. cos/sin 재사용: 같은 토큰 위치의 cos/sin 값을 BLOCK_HEADS개 헤드가 공유하므로 메모리 로드가 줄어든다.
  2. Autotune 확장: BLOCK_HEADS가 1, 2, 4, 8까지 자동 탐색되어 GPU별 최적 설정이 선택된다.
  3. Grid 크기 감소: 헤드 수 32 기준, BLOCK_HEADS=8이면 grid이 1/8로 줄어 launch 오버헤드가 감소한다.

정리

1D flatten grid를 2D grid로 전환하고 헤드 차원을 블록화하는 것은 Triton 커널 최적화의 기본 패턴이다. cos/sin broadcasting으로 메모리 효율도 동시에 개선한 깔끔한 변경이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석과 해석에서 오류가 있을 수 있으니, 정확한 내용은 원본 PR을 참고해주세요.

댓글

관련 포스트

PR Analysis 의 다른글