[sglang] [SGLang] Blackwell(B200)에서 Diffusion Attention 성능을 7배 끌어올리는 Triton 커널 최적화 분석
PR 링크: sgl-project/sglang#26318 상태: Merged | 변경: +649 / -3
들어가며
최근 Qwen-Image와 같은 멀티모달 Diffusion 모델들은 텍스트와 이미지 스트림이 섞인 가변 길이(variable length) 데이터를 처리하기 위해 USPAttention을 사용합니다. 하지만 PyTorch의 표준 SDPA(Scaled Dot Product Attention)는 attn_mask가 제공될 경우 FlashAttention을 사용하지 못하고 구형 아키텍처용인 EFFICIENT (CUTLASS SM80) 커널로 폴백(fallback)되는 고질적인 문제가 있습니다.
특히 최신 Blackwell(B200, SM 10.0) 아키텍처에서 이 SM80 경로는 네이티브 FA4(flash_fwd_sm100) 경로보다 약 7배나 느려, 전체 디노이징(denoise) 루프의 병목 지점이 됩니다. 이번 PR은 이를 해결하기 위해 가변 길이 데이터를 패킹(packing)하여 FlashAttention의 varlen 인터페이스를 직접 호출하는 최적화된 Triton 커널 쌍을 도입했습니다.
코드 분석: Triton을 이용한 Fused Pack/Scatter
기존에는 PyTorch의 기본 연산들(index_select, zeros, index_copy_)을 조합해 가변 길이를 처리할 수 있었지만, 이는 5번의 호스트 런칭(Host Launch) 오버헤드를 발생시킵니다. 이번 PR은 이를 단 2개의 Triton 커널로 퓨전(Fusion)했습니다.
1. Q/K/V 통합 패킹 커널 (fused_pack_qkv)
가장 핵심적인 변경 사항은 유효한 토큰들만 골라내어 연속된 메모리 공간으로 모으는 패킹 과정입니다. 기존 방식이 Q, K, V 각각에 대해 index_select를 호출했다면, 새로운 Triton 커널은 이를 한 번의 커널 실행으로 끝냅니다.
# python/sglang/jit_kernel/diffusion/triton/varlen_pack_pad.py
@triton.jit
def _fused_pack_qkv_kernel(
Q_ptr, K_ptr, V_ptr,
Q_unpad_ptr, K_unpad_ptr, V_unpad_ptr,
indices_ptr,
HD, src_row_stride, dst_row_stride,
BLOCK_HD: tl.constexpr,
):
# 각 프로그램은 패킹된 결과물의 한 행(row)을 담당합니다.
out_row = tl.program_id(0)
# 원본 데이터에서의 실제 인덱스를 로드합니다.
src_row = tl.load(indices_ptr + out_row).to(tl.int64)
cols = tl.arange(0, BLOCK_HD)
col_mask = cols < HD
src_offset = src_row * src_row_stride + cols
dst_offset = out_row * dst_row_stride + cols
# Q, K, V를 동시에 로드하여 패킹된 버퍼에 저장합니다.
q_val = tl.load(Q_ptr + src_offset, mask=col_mask)
k_val = tl.load(K_ptr + src_offset, mask=col_mask)
v_val = tl.load(V_ptr + src_offset, mask=col_mask)
tl.store(Q_unpad_ptr + dst_offset, q_val, mask=col_mask)
tl.store(K_unpad_ptr + dst_offset, k_val, mask=col_mask)
tl.store(V_unpad_ptr + dst_offset, v_val, mask=col_mask)
2. 출력값 복원 커널 (fused_scatter_to_padded)
FlashAttention 연산이 끝난 후, 패킹된 데이터를 다시 원래의 [B, S, H, D] 레이아웃으로 돌려놓아야 합니다. 이때 마스킹된(무효한) 위치는 0으로 채워야 하는데, 이 역시 Triton 커널로 최적화되었습니다.
@triton.jit
def _fused_scatter_to_padded_kernel(
Out_unpad_ptr, Out_padded_ptr,
inv_indices_ptr, # [B*S]: 유효하면 패킹 인덱스, 아니면 -1
HD, src_row_stride, dst_row_stride,
BLOCK_HD: tl.constexpr,
):
pad_row = tl.program_id(0)
inv_idx = tl.load(inv_indices_ptr + pad_row).to(tl.int64)
cols = tl.arange(0, BLOCK_HD)
col_mask = cols < HD
valid = inv_idx >= 0
# 유효한 행이면 데이터를 가져오고, 아니면 0.0으로 채웁니다.
safe_idx = tl.where(valid, inv_idx, 0)
src_offset = safe_idx * src_row_stride + cols
dst_offset = pad_row * dst_row_stride + cols
val = tl.load(Out_unpad_ptr + src_offset, mask=col_mask & valid, other=0.0)
tl.store(Out_padded_ptr + dst_offset, val, mask=col_mask)
왜 이게 좋은 최적화인가?
1. 호스트 오버헤드 극소화
기존 PyTorch 연산 조합 방식은 다음과 같은 단계를 거칩니다:
index_select(Q) +index_select(K) +index_select(V) +zeros(Output Buffer) +index_copy_(Scatter)
이 5번의 연산은 각각 GPU 커널 런칭을 유발하며, 특히 배치 사이즈가 작거나 시퀀스 길이가 짧은 경우 호스트 사이드의 오버헤드가 전체 실행 시간의 상당 부분을 차지하게 됩니다. 이번 PR은 이를 2번의 Triton 커널 런칭으로 줄여 Dispatch 지연 시간을 대폭 개선했습니다.
2. Blackwell 아키텍처의 잠재력 활용
B200과 같은 최신 하드웨어는 FA3/FA4와 같은 최신 알고리즘에 최적화되어 있습니다. 마스크 때문에 구형 SM80 커널로 폴백되는 것을 막고, 데이터를 수동으로 패킹하여 flash_attn_varlen_func를 호출함으로써 하드웨어 가속기를 100% 활용할 수 있게 되었습니다.
3. 성능 수치 (Performance Metrics)
B200에서 Qwen-Image-2512 모델로 테스트한 결과:
- GPU Active Time: 최대 21.0% 감소 (Batch size 6 기준)
- End-to-End Latency: 최대 16.9% 개선 (Batch size 20 기준)
- 커널 수준에서
fmha_cutlassF가 29,720ms 걸리던 작업이 단 38ms로 단축되었습니다.
리뷰어 피드백 반영
리뷰 과정에서 BBuf와 nvpohanh는 안전성과 범용성에 대한 피드백을 주었습니다.
- 장치 일치 확인:
attn_mask.device == q.device체크를 추가하여 런타임 에러를 방지했습니다. - 메타데이터 재사용:
build_varlen_mask_meta를 통해 인덱스 계산을 요청당 한 번만 수행하고 모든 블록에서 재사용하도록 설계하여 중복 계산을 없앴습니다. - 엄격한 게이팅:
SGLANG_VARLEN_FA=0환경 변수를 통해 언제든 기존 안전한 경로로 돌아갈 수 있는 기능을 제공합니다.
결론
이번 최적화는 단순히 "더 빠른 커널"을 쓰는 것을 넘어, 상위 프레임워크(PyTorch SDPA)의 한계를 이해하고 하위 수준(Triton)에서 데이터를 재구조화하여 최적의 경로를 강제했다는 점에서 큰 의미가 있습니다. 특히 Blackwell과 같은 최신 GPU를 사용하는 환경에서 Diffusion 모델의 서빙 성능을 극대화하려는 엔지니어들에게 훌륭한 레퍼런스가 될 것입니다.
참고 자료
- flash_attn_varlen_func — 가변 길이 시퀀스를 위한 FlashAttention 공식 구현
- triton.jit — Triton 커널 작성을 위한 공식 문서
- torch.nn.functional.scaled_dot_product_attention — PyTorch SDPA 공식 문서
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] vLLM의 MoE Permute 최적화: 버퍼 사전 할당을 통한 성능 향상
- 현재글 : [sglang] [SGLang] Blackwell(B200)에서 Diffusion Attention 성능을 7배 끌어올리는 Triton 커널 최적화 분석
- 다음글 [feast] Feast Feature Server의 직렬화 성능 4배 향상: MessageToDict 최적화
댓글