[Triton] GFX1250용 stream-k 및 attention decode 커널 업데이트
들어가며
AMD gfx1250용 Gluon 예제 커널에 두 가지 주요 업데이트가 이루어졌다: attention decode 커널의 추가와, 기존 flash attention 커널의 LDS reduction 및 split-k 지원이다. Attention decode는 추론 시 KV-cache에서 단일 query 토큰에 대해 attention을 계산하는 연산으로, GPU 점유율을 높이기 위해 K 차원을 분할하는 것이 일반적이다.
핵심 코드 분석
Decode Forward 커널
@gluon.jit
def attn_decode_fwd_kernel(..., SPLIT_FACTOR: gl.constexpr, CHUNK_SIZE: gl.constexpr):
split_id = gl.program_id(2)
start_k = split_id * CHUNK_SIZE
end_k = min(start_k + CHUNK_SIZE, SEQLEN_K)
for current_k in range(start_k, end_k, BLOCK_N):
# K, V 로드 및 attention 계산
pgm.tdm_load_global_to_shared_k([current_k, 0], buffer_index=0)
k = pgm.tdm_shared_load_k(0, wait_count=0)
qk = pgm.compute_qk(k, current_k)
# split 경계 마스킹
extra_mask = (current_k + gl.arange(0, BLOCK_N, ...)) < end_k
qk = gl.where(extra_mask, qk, float("-inf"))
# ... softmax + PV 계산
Reduce 커널
@gluon.jit
def attn_decode_reduce_kernel(..., SPLIT_FACTOR: gl.constexpr):
for s in range(SPLIT_FACTOR):
m_s = gl.load(mid_m_ptr + ...)
l_s = gl.load(mid_l_ptr + ...)
acc_s = gl.amd.gfx1250.buffer_load(mid_o_ptr, mid_o_offs)
m_new = gl.maximum(m_global, m_s)
alpha = gl.exp2((m_global - m_new) * SM_SCALE * rcp_ln2)
beta = gl.exp2((m_s - m_new) * SM_SCALE * rcp_ln2)
l_global = l_global * alpha + l_s * beta
acc_global = acc_global * alpha[:, None] + acc_s * beta[:, None]
split-k attention의 표준 패턴: 각 split의 partial softmax 결과(m_s, l_s, acc_s)를 online softmax로 합산한다.
왜 이게 좋은가
- GPU 활용도 극대화: decode 시 query가 1개뿐이라 병렬성이 부족한데, split-k로 K 차원을 분할하여 GPU를 활용한다.
- TDM 활용: gfx1250의 Tensor Data Mover로 K, V를 비동기 로드한다.
- 자동 split factor 계산: target workgroup 수(1024)를 기반으로 최적 split factor를 자동 결정한다.
정리
+297/-30 변경으로, gfx1250에서의 attention decode 최적화를 위한 완전한 2-커널(forward + reduce) 패턴을 구현했다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글