본문으로 건너뛰기

[Triton] GFX1250용 stream-k 및 attention decode 커널 업데이트

들어가며

AMD gfx1250용 Gluon 예제 커널에 두 가지 주요 업데이트가 이루어졌다: attention decode 커널의 추가와, 기존 flash attention 커널의 LDS reduction 및 split-k 지원이다. Attention decode는 추론 시 KV-cache에서 단일 query 토큰에 대해 attention을 계산하는 연산으로, GPU 점유율을 높이기 위해 K 차원을 분할하는 것이 일반적이다.

핵심 코드 분석

Decode Forward 커널

@gluon.jit
def attn_decode_fwd_kernel(..., SPLIT_FACTOR: gl.constexpr, CHUNK_SIZE: gl.constexpr):
    split_id = gl.program_id(2)
    start_k = split_id * CHUNK_SIZE
    end_k = min(start_k + CHUNK_SIZE, SEQLEN_K)

    for current_k in range(start_k, end_k, BLOCK_N):
        # K, V 로드 및 attention 계산
        pgm.tdm_load_global_to_shared_k([current_k, 0], buffer_index=0)
        k = pgm.tdm_shared_load_k(0, wait_count=0)
        qk = pgm.compute_qk(k, current_k)

        # split 경계 마스킹
        extra_mask = (current_k + gl.arange(0, BLOCK_N, ...)) < end_k
        qk = gl.where(extra_mask, qk, float("-inf"))
        # ... softmax + PV 계산

Reduce 커널

@gluon.jit
def attn_decode_reduce_kernel(..., SPLIT_FACTOR: gl.constexpr):
    for s in range(SPLIT_FACTOR):
        m_s = gl.load(mid_m_ptr + ...)
        l_s = gl.load(mid_l_ptr + ...)
        acc_s = gl.amd.gfx1250.buffer_load(mid_o_ptr, mid_o_offs)

        m_new = gl.maximum(m_global, m_s)
        alpha = gl.exp2((m_global - m_new) * SM_SCALE * rcp_ln2)
        beta = gl.exp2((m_s - m_new) * SM_SCALE * rcp_ln2)
        l_global = l_global * alpha + l_s * beta
        acc_global = acc_global * alpha[:, None] + acc_s * beta[:, None]

split-k attention의 표준 패턴: 각 split의 partial softmax 결과(m_s, l_s, acc_s)를 online softmax로 합산한다.

왜 이게 좋은가

  • GPU 활용도 극대화: decode 시 query가 1개뿐이라 병렬성이 부족한데, split-k로 K 차원을 분할하여 GPU를 활용한다.
  • TDM 활용: gfx1250의 Tensor Data Mover로 K, V를 비동기 로드한다.
  • 자동 split factor 계산: target workgroup 수(1024)를 기반으로 최적 split factor를 자동 결정한다.

정리

+297/-30 변경으로, gfx1250에서의 attention decode 최적화를 위한 완전한 2-커널(forward + reduce) 패턴을 구현했다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글