[SGLang] Triton Attention 커널: Python으로 작성하는 GPU 커널
들어가며
GPU 커널을 CUDA C++로 작성하면 성능은 좋지만, 개발과 디버깅이 어렵다. OpenAI가 만든 Triton은 Python으로 GPU 커널을 작성할 수 있게 해주는 컴파일러이다. SGLang의 TritonAttnBackend는 Triton으로 작성한 Decode, Extend(Prefill), Merge 커널을 조합하여 FlashInfer나 FlashAttention 없이도 어텐션을 수행한다.
이 글에서는 python/sglang/srt/layers/attention/triton_backend.py와 triton_ops/ 디렉토리를 분석한다.
구조도
┌──────────────────────────────────────────────────┐
│ TritonAttnBackend │
│ │
│ init_forward_metadata() │
│ ├── Decode: kv_indptr, kv_indices 구성 │
│ ├── Extend: qo_indptr, custom_mask 구성 │
│ └── Target Verify: mask_indptr 구성 │
│ │
│ forward_decode() ──▶ decode_attention_fwd │
│ forward_extend() ──▶ extend_attention_fwd │
└──────────────────────────────────────────────────┘
triton_ops/
├── decode_attention.py ← Flash Decoding (2-stage)
├── extend_attention.py ← Prefill/Extend 커널
├── prefill_attention.py ← Context Attention
├── merge_state.py ← 분할 결과 병합
├── double_sparsity_attention.py
├── rocm_mla_decode_rope.py
└── trtllm_fp8_kv_kernel.py
핵심 코드 분석
TritonAttnBackend 초기화: 커널 바인딩과 버퍼 할당
class TritonAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner, skip_prefill=False, ...):
from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
build_unified_kv_indices,
extend_attention_fwd,
extend_attention_fwd_unified,
)
self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
Triton 커널을 lazy import하고 torch.compiler.disable로 감싼다. PyTorch 컴파일러가 Triton 커널 내부를 추적하지 않도록 하는 것이다. Triton 커널은 이미 자체 컴파일러로 최적화되므로 torch.compile의 추적은 불필요하다.
KV Split 결정: GPU 코어 수 기반 동적 분할
def get_num_kv_splits(self, num_kv_splits, seq_lens):
if self.enable_deterministic:
# 결정적 추론: 고정 타일 크기 사용
num_kv_splits[:] = (
expanded_seq_lens + self.split_tile_size - 1
) // self.split_tile_size
return
get_num_kv_splits_triton[(1,)](
num_kv_splits, seq_lens,
num_seq, num_group,
self.num_head, self.num_kv_head,
self.max_kv_splits, self.device_core_count,
MAX_NUM_SEQ=SCHEDULE_SEQ,
)
Decode에서 긴 KV 캐시를 한 번에 처리하면 GPU SM 활용도가 낮아진다. SGLang은 KV 캐시를 여러 split으로 나누어 병렬 처리한다. 결정적 추론 모드에서는 split_tile_size로 고정 분할하고, 기본 모드에서는 GPU 코어 수와 시퀀스 길이를 기반으로 Triton 커널이 동적으로 최적 split 수를 계산한다.
Decode 커널: 2-Stage Flash Decoding
# decode_attention.py
@triton.jit
def _fwd_kernel_stage1(
Q, K_Buffer, V_Buffer, sm_scale_withk,
kv_indptr, kv_indices,
Att_Out, Att_Lse,
num_kv_splits,
...
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
...
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
split_kv_id = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
Decode 커널은 LightLLM에서 차용한 2-Stage Flash Decoding 패턴이다. Stage 1에서 각 KV split에 대해 부분 어텐션을 계산하고 중간 결과(logits + LSE)를 저장한다. Stage 2에서 모든 split의 결과를 log-sum-exp로 안전하게 병합한다. 3D 그리드 (batch, head, split)로 GPU 병렬성을 최대화한다.
Extend 커널: Prefill + KV Cache 통합
# extend_attention.py
def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
if Lq == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lq == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
elif Lq == 192:
BLOCK_DMODEL = 128
BLOCK_DPE = 64
else:
BLOCK_DMODEL = triton.next_power_of_2(Lq)
BLOCK_DPE = 0
Extend 커널은 Prefill과 기존 KV 캐시를 함께 어텐션한다. Head dimension에 따라 블록 크기를 조정하며, DeepSeek-V2의 MLA(576차원), 288차원 등 특수한 head dim을 위한 전용 설정이 있다. BLOCK_DPE는 position embedding을 위한 별도 차원이다. 아키텍처별 최적 블록 크기도 세밀하게 조정한다(Blackwell sm120, Hopper sm100, HIP 등).
Sliding Window 지원
if self.sliding_window_size is not None and self.sliding_window_size > 0:
self.window_kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
Gemma3 같은 모델은 일부 레이어에서 Sliding Window Attention을 사용한다. Triton 백엔드는 일반 KV 인덱스(kv_indptr, kv_indices)와 별도로 윈도우 KV 인덱스(window_kv_indptr, window_kv_indices)를 관리한다. SWA 레이어와 Full Attention 레이어의 v_head_dim이 다를 수 있어(예: Gemma4에서 SWA=256, Full=512), 별도의 swa_attn_logits 버퍼도 할당한다.
ForwardMetadata: 단계별 메타데이터 구조
@dataclass
class ForwardMetadata:
attn_logits: torch.Tensor # Decode용 중간 결과
attn_lse: torch.Tensor # Decode용 LSE
max_extend_len: int # Extend 최대 길이
num_kv_splits: torch.Tensor # 분할 수
kv_indptr: torch.Tensor # KV 인덱스 포인터
kv_indices: torch.Tensor # KV 인덱스
qo_indptr: torch.Tensor # Query 인덱스 포인터
custom_mask: torch.Tensor # Speculative용 커스텀 마스크
mask_indptr: torch.Tensor # 마스크 인덱스 포인터
window_kv_indptr: torch.Tensor # Sliding Window KV 포인터
window_kv_indices: torch.Tensor
window_num_kv_splits: torch.Tensor
window_kv_offsets: torch.Tensor
Decode, Extend, Target Verify, Draft Extend 각 모드에서 필요한 메타데이터가 다르다. ForwardMetadata는 모든 모드의 메타데이터를 하나의 구조에 담되, 각 모드에서 사용하지 않는 필드는 None으로 둔다.
Triton vs CUDA 커널 비교
| 특성 | Triton | CUDA C++ |
|---|---|---|
| 언어 | Python | C++/PTX |
| 개발 속도 | 빠름 | 느림 |
| 디버깅 | Python 도구 사용 가능 | nsight, cuda-gdb |
| 메모리 관리 | 자동 (tl.load/store) | 수동 (shared memory) |
| 성능 | CUDA의 ~90-95% | 최적 |
| 이식성 | AMD ROCm 지원 | NVIDIA 전용 |
SGLang이 Triton 백엔드를 별도로 유지하는 이유는 AMD GPU 지원과 빠른 실험이다. FlashInfer는 CUDA 전용이므로, ROCm 환경에서는 Triton 백엔드가 유일한 선택지이다.
관련 포스트
- Hybrid Attention: Dense-Sparse 동적 전환 전략
- GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Hybrid Attention: Dense-Sparse 동적 전환 전략
- 현재글 : [SGLang] Triton Attention 커널: Python으로 작성하는 GPU 커널
- 다음글 [SGLang] Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링
댓글