[triton] AMD GFX1250용 Warp-Pipeline f16 GEMM 예제 추가
PR 링크: triton-lang/triton#9382 상태: Merged | 변경: +179 / -0
들어가며
AMD GFX1250은 TDM(Triton Data Mover)과 warp pipeline을 지원하는 최신 아키텍처입니다. 이 PR은 이 두 기능을 결합한 f16 GEMM(General Matrix Multiplication) 커널 예제를 추가합니다. Triple buffering과 warp pipeline stage를 활용하여 메모리 로드와 연산을 효과적으로 오버랩합니다.
핵심 코드 분석
핵심 루프 구조 - warp pipeline의 2단계 파이프라인:
# Triple buffering: 2개 prefetch
for _ in ttgl.static_range(2):
producer = issue_loads(producer, a_desc, b_desc, 0, 0,
a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B)
ttgl.amd.gfx1250.tdm.async_wait(1 * 2)
for _ in range(0, ttgl.cdiv(K, BLOCK_K) - (NUM_BUFFERS - 1)):
with ttgl.amd.warp_pipeline_stage("stage0", priority=1):
consumer, a, b = lds_load(consumer, a_buffer, OPERAND_LAYOUT_A,
b_buffer, OPERAND_LAYOUT_B, NUM_BUFFERS,
TRANSPOSE_B)
ttgl.amd.gfx1250.tdm.async_wait(0)
with ttgl.amd.warp_pipeline_stage("stage1", priority=0):
producer = issue_loads(producer, a_desc, b_desc, ...)
accumulator = issue_wmma_compute(a, b, accumulator)
WMMA layout은 GFX1250의 16x16x32 instrShape을 사용:
WMMA_LAYOUT = ttgl.amd.AMDWMMALayout(3, True, WARP_BASES, [],
[16, 16, 32])
왜 이게 좋은가
Warp pipeline은 하나의 warp 내에서 두 개의 코드 경로(stage0: LDS load, stage1: compute + global load)를 시분할(time-multiplexing)하여 실행합니다. priority=1인 stage0이 priority=0인 stage1보다 먼저 실행되므로, LDS에서 데이터를 읽는 동안 이전 데이터의 WMMA 연산과 다음 데이터의 global load가 병행됩니다. Triple buffering(NUM_BUFFERS=3)으로 충분한 버퍼를 확보하여 파이프라인 bubble을 최소화합니다.
정리
AMD GFX1250에서 TDM async load + warp pipeline stage + triple buffering을 결합한 f16 GEMM 예제를 추가하고, 256x256x64 타일 크기의 correctness 테스트를 포함했습니다.
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Loki] 대소문자 무시 정규식을 바이너리 연산자로 최적화
- 현재글 : [triton] AMD GFX1250용 Warp-Pipeline f16 GEMM 예제 추가
- 다음글 [Triton] TMA im2col 모드 — LLVM Lowering 구현
댓글