본문으로 건너뛰기

[triton] Triton AMD 백엔드: 8-Wave PingPong Attention 커널 구현 분석

PR 링크: triton-lang/triton#9427 상태: Merged | 변경: +262 / -32

들어가며

최신 LLM 모델의 성능을 결정짓는 핵심 연산인 Attention 커널은 GPU의 메모리 대역폭과 연산 유닛을 얼마나 효율적으로 활용하느냐에 따라 성능이 크게 좌우됩니다. 특히 AMD GPU 환경에서 Triton을 사용할 때, 기존의 파이프라이닝 방식만으로는 SIMD 유닛의 활용도를 극대화하는 데 한계가 있었습니다. 이번 PR에서는 'PingPong' 기법을 도입하여 동일한 SIMD 내에서 두 개의 웨이브(wave)가 협력적으로 실행되도록 함으로써, 8-Wave 기반의 고성능 Attention 커널을 구현했습니다.

코드 분석

1. AttentionConfig 및 NUM_WARPS 확장

기존에는 4-warp 기반의 설정이 주를 이루었으나, 8-wave PingPong 커널을 지원하기 위해 AttentionConfig 클래스를 수정했습니다.

# Before
@gluon.constexpr_function
def __init__(self, SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS):
    # ...

# After
@gluon.constexpr_function
def __init__(self, SEQLEN_Q, SEQLEN_K, HEAD_SZ, BLOCK_M, BLOCK_N, NUM_BUFFERS, NUM_WARPS):
    assert NUM_WARPS == 4 or NUM_WARPS == 8
    if NUM_WARPS == 4:
        warp_bases = [[1, 0], [2, 0]]
    else:
        warp_bases = [[1, 0], [2, 0], [4, 0]]

warp_bases를 동적으로 설정하여 8-warp 환경에서도 적절한 레지스터 및 메모리 레이아웃을 가져갈 수 있도록 개선했습니다.

2. PingPong 커널 구현 (attn_fwd_pingpong_pipelined_kernel)

핵심은 warp_pipeline_stage를 활용한 연산과 메모리 로드의 중첩입니다. PingPong 기법은 한 웨이브가 연산을 수행하는 동안 다른 웨이브가 데이터를 프리페치(prefetch)하도록 설계되었습니다.

# Hot Loop 내의 파이프라인 스테이지
with gl.amd.warp_pipeline_stage("stage0", priority=0):
    qk = pgm.compute_qk_no_mask(k)

gl.amd.gfx1250.tdm.async_wait(2)
with gl.amd.warp_pipeline_stage("stage1", priority=1):
    p, l_i, acc = pgm.softmax_part1(p, l_i, acc, alpha)
    v = pgm.v_buffer.index(iter_id % NUM_BUFFERS).load(layout=pgm.cfg.v_layout)
    pgm.tdm_load_global_to_shared_k([t_3, 0], (iter_id + 1) % NUM_BUFFERS)

이 구조는 async_wait를 통해 메모리 로드 완료를 기다리면서도, 연산 파이프라인이 멈추지 않도록 하여 GPU 점유율을 극대화합니다.

왜 이게 좋은가

  1. SIMD 활용도 극대화: 8-Wave PingPong 방식은 동일한 SIMD 유닛 내에서 연산과 메모리 접근을 교차(interleaving)시킴으로써, 메모리 지연 시간(latency)을 숨기고 연산 유닛의 유휴 시간을 최소화합니다.
  2. 유연한 커널 선택: ATTN_FN을 문자열 기반으로 관리하도록 변경하여, 사용자가 상황에 맞는 최적의 커널(pipeline vs pingpong)을 쉽게 선택할 수 있게 되었습니다.
  3. 교훈: GPU 최적화에서 '파이프라이닝'은 단순히 루프를 푸는 것이 아니라, 데이터 로드와 연산의 의존성을 분석하여 하드웨어의 비동기 엔진(TDM 등)을 얼마나 잘 활용하느냐가 핵심임을 보여줍니다.

이번 변경은 특히 AMD GFX1250 아키텍처에서 대규모 Attention 연산 시 처리량(throughput)을 유의미하게 개선할 것으로 기대됩니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글