본문으로 건너뛰기

[triton] AMD GFX1250용 Warp-Pipeline f16 GEMM 예제 추가

PR 링크: triton-lang/triton#9382 상태: Merged | 변경: +179 / -0

들어가며

AMD GFX1250은 TDM(Triton Data Mover)과 warp pipeline을 지원하는 최신 아키텍처입니다. 이 PR은 이 두 기능을 결합한 f16 GEMM(General Matrix Multiplication) 커널 예제를 추가합니다. Triple buffering과 warp pipeline stage를 활용하여 메모리 로드와 연산을 효과적으로 오버랩합니다.

핵심 코드 분석

핵심 루프 구조 - warp pipeline의 2단계 파이프라인:

# Triple buffering: 2개 prefetch
for _ in ttgl.static_range(2):
    producer = issue_loads(producer, a_desc, b_desc, 0, 0,
                          a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B)

ttgl.amd.gfx1250.tdm.async_wait(1 * 2)

for _ in range(0, ttgl.cdiv(K, BLOCK_K) - (NUM_BUFFERS - 1)):
    with ttgl.amd.warp_pipeline_stage("stage0", priority=1):
        consumer, a, b = lds_load(consumer, a_buffer, OPERAND_LAYOUT_A,
                                  b_buffer, OPERAND_LAYOUT_B, NUM_BUFFERS,
                                  TRANSPOSE_B)
    ttgl.amd.gfx1250.tdm.async_wait(0)
    with ttgl.amd.warp_pipeline_stage("stage1", priority=0):
        producer = issue_loads(producer, a_desc, b_desc, ...)
        accumulator = issue_wmma_compute(a, b, accumulator)

WMMA layout은 GFX1250의 16x16x32 instrShape을 사용:

WMMA_LAYOUT = ttgl.amd.AMDWMMALayout(3, True, WARP_BASES, [],
                                      [16, 16, 32])

왜 이게 좋은가

Warp pipeline은 하나의 warp 내에서 두 개의 코드 경로(stage0: LDS load, stage1: compute + global load)를 시분할(time-multiplexing)하여 실행합니다. priority=1인 stage0이 priority=0인 stage1보다 먼저 실행되므로, LDS에서 데이터를 읽는 동안 이전 데이터의 WMMA 연산과 다음 데이터의 global load가 병행됩니다. Triple buffering(NUM_BUFFERS=3)으로 충분한 버퍼를 확보하여 파이프라인 bubble을 최소화합니다.

정리

AMD GFX1250에서 TDM async load + warp pipeline stage + triple buffering을 결합한 f16 GEMM 예제를 추가하고, 256x256x64 타일 크기의 correctness 테스트를 포함했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글