본문으로 건너뛰기

[SGLang] Triton Attention 커널: Python으로 작성하는 GPU 커널

들어가며

GPU 커널을 CUDA C++로 작성하면 성능은 좋지만, 개발과 디버깅이 어렵다. OpenAI가 만든 Triton은 Python으로 GPU 커널을 작성할 수 있게 해주는 컴파일러이다. SGLang의 TritonAttnBackend는 Triton으로 작성한 Decode, Extend(Prefill), Merge 커널을 조합하여 FlashInfer나 FlashAttention 없이도 어텐션을 수행한다.

이 글에서는 python/sglang/srt/layers/attention/triton_backend.pytriton_ops/ 디렉토리를 분석한다.

구조도

┌──────────────────────────────────────────────────┐
│              TritonAttnBackend                    │
│                                                  │
│  init_forward_metadata()                         │
│    ├── Decode: kv_indptr, kv_indices 구성         │
│    ├── Extend: qo_indptr, custom_mask 구성        │
│    └── Target Verify: mask_indptr 구성            │
│                                                  │
│  forward_decode() ──▶ decode_attention_fwd       │
│  forward_extend() ──▶ extend_attention_fwd       │
└──────────────────────────────────────────────────┘

triton_ops/
├── decode_attention.py    ← Flash Decoding (2-stage)
├── extend_attention.py    ← Prefill/Extend 커널
├── prefill_attention.py   ← Context Attention
├── merge_state.py         ← 분할 결과 병합
├── double_sparsity_attention.py
├── rocm_mla_decode_rope.py
└── trtllm_fp8_kv_kernel.py

핵심 코드 분석

TritonAttnBackend 초기화: 커널 바인딩과 버퍼 할당

class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner, skip_prefill=False, ...):
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            build_unified_kv_indices,
            extend_attention_fwd,
            extend_attention_fwd_unified,
        )

        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

Triton 커널을 lazy import하고 torch.compiler.disable로 감싼다. PyTorch 컴파일러가 Triton 커널 내부를 추적하지 않도록 하는 것이다. Triton 커널은 이미 자체 컴파일러로 최적화되므로 torch.compile의 추적은 불필요하다.

KV Split 결정: GPU 코어 수 기반 동적 분할

def get_num_kv_splits(self, num_kv_splits, seq_lens):
    if self.enable_deterministic:
        # 결정적 추론: 고정 타일 크기 사용
        num_kv_splits[:] = (
            expanded_seq_lens + self.split_tile_size - 1
        ) // self.split_tile_size
        return

    get_num_kv_splits_triton[(1,)](
        num_kv_splits, seq_lens,
        num_seq, num_group,
        self.num_head, self.num_kv_head,
        self.max_kv_splits, self.device_core_count,
        MAX_NUM_SEQ=SCHEDULE_SEQ,
    )

Decode에서 긴 KV 캐시를 한 번에 처리하면 GPU SM 활용도가 낮아진다. SGLang은 KV 캐시를 여러 split으로 나누어 병렬 처리한다. 결정적 추론 모드에서는 split_tile_size로 고정 분할하고, 기본 모드에서는 GPU 코어 수와 시퀀스 길이를 기반으로 Triton 커널이 동적으로 최적 split 수를 계산한다.

Decode 커널: 2-Stage Flash Decoding

# decode_attention.py
@triton.jit
def _fwd_kernel_stage1(
    Q, K_Buffer, V_Buffer, sm_scale_withk,
    kv_indptr, kv_indices,
    Att_Out, Att_Lse,
    num_kv_splits,
    ...
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    ...
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_kv_id = tl.program_id(2)
    cur_kv_head = cur_head // kv_group_num

Decode 커널은 LightLLM에서 차용한 2-Stage Flash Decoding 패턴이다. Stage 1에서 각 KV split에 대해 부분 어텐션을 계산하고 중간 결과(logits + LSE)를 저장한다. Stage 2에서 모든 split의 결과를 log-sum-exp로 안전하게 병합한다. 3D 그리드 (batch, head, split)로 GPU 병렬성을 최대화한다.

Extend 커널: Prefill + KV Cache 통합

# extend_attention.py
def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
    if Lq == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lq == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    elif Lq == 192:
        BLOCK_DMODEL = 128
        BLOCK_DPE = 64
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lq)
        BLOCK_DPE = 0

Extend 커널은 Prefill과 기존 KV 캐시를 함께 어텐션한다. Head dimension에 따라 블록 크기를 조정하며, DeepSeek-V2의 MLA(576차원), 288차원 등 특수한 head dim을 위한 전용 설정이 있다. BLOCK_DPE는 position embedding을 위한 별도 차원이다. 아키텍처별 최적 블록 크기도 세밀하게 조정한다(Blackwell sm120, Hopper sm100, HIP 등).

Sliding Window 지원

if self.sliding_window_size is not None and self.sliding_window_size > 0:
    self.window_kv_indptr = torch.zeros(
        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
    )

Gemma3 같은 모델은 일부 레이어에서 Sliding Window Attention을 사용한다. Triton 백엔드는 일반 KV 인덱스(kv_indptr, kv_indices)와 별도로 윈도우 KV 인덱스(window_kv_indptr, window_kv_indices)를 관리한다. SWA 레이어와 Full Attention 레이어의 v_head_dim이 다를 수 있어(예: Gemma4에서 SWA=256, Full=512), 별도의 swa_attn_logits 버퍼도 할당한다.

ForwardMetadata: 단계별 메타데이터 구조

@dataclass
class ForwardMetadata:
    attn_logits: torch.Tensor      # Decode용 중간 결과
    attn_lse: torch.Tensor         # Decode용 LSE
    max_extend_len: int            # Extend 최대 길이
    num_kv_splits: torch.Tensor    # 분할 수
    kv_indptr: torch.Tensor        # KV 인덱스 포인터
    kv_indices: torch.Tensor       # KV 인덱스
    qo_indptr: torch.Tensor        # Query 인덱스 포인터
    custom_mask: torch.Tensor      # Speculative용 커스텀 마스크
    mask_indptr: torch.Tensor      # 마스크 인덱스 포인터
    window_kv_indptr: torch.Tensor # Sliding Window KV 포인터
    window_kv_indices: torch.Tensor
    window_num_kv_splits: torch.Tensor
    window_kv_offsets: torch.Tensor

Decode, Extend, Target Verify, Draft Extend 각 모드에서 필요한 메타데이터가 다르다. ForwardMetadata는 모든 모드의 메타데이터를 하나의 구조에 담되, 각 모드에서 사용하지 않는 필드는 None으로 둔다.

Triton vs CUDA 커널 비교

특성 Triton CUDA C++
언어 Python C++/PTX
개발 속도 빠름 느림
디버깅 Python 도구 사용 가능 nsight, cuda-gdb
메모리 관리 자동 (tl.load/store) 수동 (shared memory)
성능 CUDA의 ~90-95% 최적
이식성 AMD ROCm 지원 NVIDIA 전용

SGLang이 Triton 백엔드를 별도로 유지하는 이유는 AMD GPU 지원과 빠른 실험이다. FlashInfer는 CUDA 전용이므로, ROCm 환경에서는 Triton 백엔드가 유일한 선택지이다.

관련 포스트

  • Hybrid Attention: Dense-Sparse 동적 전환 전략
  • GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션

참고

댓글

관련 포스트

SGLang 의 다른글