본문으로 건너뛰기

[triton] AMD MXFP FA 예제에서 TDM Store 도입으로 Output 저장 최적화

PR 링크: triton-lang/triton#9813 상태: Merged | 변경: +26 / -100

들어가며

AMD GFX1250의 Flash Attention 예제에서 output 저장은 기존에 buffer_store와 수동 store_layout을 사용했습니다. 이 PR은 이를 TDM(Triton Data Mover) store로 전환하여 약 74줄의 레이아웃 관련 코드를 제거합니다.

핵심 코드 분석

Before (수동 레이아웃 + buffer_store):

store_layout = get_store_layout([BLOCK_M, HEAD_SZ], NUM_WARPS)

o_blk = MemoryBlock.initialize(
    o_ptr + o_off,
    shape=[SEQLEN_Q, HEAD_SZ],
    block_shape=[BLOCK_M, HEAD_SZ],
    layout=cfg.store_layout)

o = acc.to(o_blk.dtype)
o = ttgl.convert_layout(o, cfg.store_layout)
buffer_store(o, o_blk.ptr, o_blk.offs, o_blk.mask)

After (TDM store):

# store_layout 필요 없음, get_store_layout 함수 자체 제거
o_base = SEQLEN_Q * HEAD_SZ * (NUM_Q_HEADS * off_z + off_h)
o_shape = [SEQLEN_Q, HEAD_SZ]
# TDM store가 레이아웃 변환을 자동 처리

또한 get_store_layoutsplit_n 유틸리티 함수(약 60줄)가 완전히 제거되었습니다.

왜 이게 좋은가

TDM store는 하드웨어가 최적의 메모리 접근 패턴을 선택하므로, 수동으로 store_layout을 계산하고 convert_layout을 수행할 필요가 없습니다. 이는 코드 복잡성을 크게 줄이면서도 하드웨어 최적화를 유지합니다. 특히 get_store_layout 함수는 "128B contiguous store를 위해 각 lane이 inner dim의 절반을 담당" 같은 복잡한 레이아웃 설계를 포함했는데, TDM이 이를 자동으로 처리합니다.

정리

MXFP FA 예제에서 buffer_store를 TDM store로 전환하고, 수동 store 레이아웃 계산 코드(약 74줄)를 제거하여 코드를 단순화했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글