[SGLang] Diffusion Triton Rotary Embedding 다중 헤드 병렬 처리 최적화
PR 링크: sgl-project/sglang#21387 상태: Merged | 변경: +46 / -26
들어가며
Diffusion 모델에서 rotary position embedding(RoPE)은 모든 attention layer에서 호출되는 핵심 연산이다. 기존 Triton 커널은 (batch * num_tokens * num_heads) 크기의 1D grid로 launch되어, 각 program instance가 하나의 (토큰, 헤드) 쌍만 처리했다. 이는 헤드 수가 많을 때 grid 크기가 과도하게 커지고 스케줄링 오버헤드가 발생한다.
이 PR은 grid를 (batch * num_tokens, ceil(num_heads / BLOCK_HEADS)) 2D로 재구성하여, 각 instance가 BLOCK_HEADS개의 헤드를 동시에 처리한다.
핵심 코드 분석
Grid 차원 변경
Before:
x_reshaped = x.view(-1, head_size) # [B*T*H, head_size]
grid = (bsz * num_tokens * num_heads,) # 1D: 모든 (토큰,헤드) 조합
After:
x_reshaped = x.view(bsz * num_tokens, num_heads, head_size) # [B*T, H, hs]
# 2D grid: (토큰, 헤드 블록)
_rotary_embedding_kernel[
lambda META: (bsz * num_tokens, triton.cdiv(num_heads, META["BLOCK_HEADS"]))
](...)
커널 내부 다중 헤드 처리
Before:
row_idx = tl.program_id(0)
token_idx = (row_idx // num_heads) % num_tokens
x_row_ptr = x_ptr + row_idx * stride_x_row
# 1개 헤드의 x1, x2를 로드하여 RoPE 적용
x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask)
After:
bt_idx = tl.program_id(0)
head_block_idx = tl.program_id(1)
head_offsets = head_block_idx * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS)
head_mask = head_offsets < num_heads
# cos/sin은 헤드 간 공유 (1D 로드)
cos_vals = tl.load(cos_row_ptr + offsets_half, mask=half_mask)
# x는 [BLOCK_HEADS, BLOCK_HS_HALF] 2D 로드
x1_vals = tl.load(x_row_ptrs + offsets_x1[None, :], mask=mask)
# Broadcasting: cos/sin [1, HS] * x [BH, HS]
cos_fp32 = cos_vals.to(tl.float32)[None, :]
o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
왜 이게 좋은가
- cos/sin 재사용: 같은 토큰 위치의 cos/sin 값을
BLOCK_HEADS개 헤드가 공유하므로 메모리 로드가 줄어든다. - Autotune 확장:
BLOCK_HEADS가 1, 2, 4, 8까지 자동 탐색되어 GPU별 최적 설정이 선택된다. - Grid 크기 감소: 헤드 수 32 기준,
BLOCK_HEADS=8이면 grid이 1/8로 줄어 launch 오버헤드가 감소한다.
정리
1D flatten grid를 2D grid로 전환하고 헤드 차원을 블록화하는 것은 Triton 커널 최적화의 기본 패턴이다. cos/sin broadcasting으로 메모리 효율도 동시에 개선한 깔끔한 변경이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석과 해석에서 오류가 있을 수 있으니, 정확한 내용은 원본 PR을 참고해주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [SGLang] wait-for-jobs에 ETag conditional request 도입으로 API rate limit 절약
- 현재글 : [SGLang] Diffusion Triton Rotary Embedding 다중 헤드 병렬 처리 최적화
- 다음글 [triton] Triton AMD 백엔드 최적화: SGPR 활용과 루프 최적화를 통한 GEMM 성능 향상
댓글