[triton] AMD gfx1250 Gluon에 Tensor Async Gather(TDM) 지원 추가
PR 링크: triton-lang/triton#9313 상태: Merged | 변경: +505 / -54
들어가며
TDM gather는 global memory의 비연속적인 행들에서 데이터를 shared memory로 비동기적으로 읽어오는 연산입니다. 앞서 추가된 scatter의 대칭 연산으로, Flash Attention에서 KV cache의 비연속 토큰 위치에서 데이터를 수집할 때 유용합니다.
핵심 코드 분석
Python API
@builtin
def async_gather(desc: tensor_descriptor, src_row_indices: ttgl.tensor,
src_col_offset, dst: shared_memory_descriptor,
mbarrier=None, _semantic=None) -> None:
"""Gather data from non-contiguous rows in global memory.
src_row_indices의 dtype에 따라:
- int16: 한 번에 최대 16행
- int32: 한 번에 최대 8행
"""
MLIR Op 정의
def AsyncTDMGatherOp : TT_AMDGPU_Op<"async_tdm_gather"> {
let arguments = (ins
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
TensorOf<[I16, I32]>:$src_row_indices,
I32:$src_col_offset,
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$dst,
Optional<TTG_MemDescType>:$barrier
);
}
LLVM Lowering에서 scatter와 코드 공유
// scatter와 gather가 동일한 emitTDMGatherScatter 함수 사용
mlir::LLVM::AMD::emitTDMGatherScatter(
rewriter, loc, getTypeConverter(), desc, shapePerCTA, srcPtr, pred,
elementType, barrierPtr, cgaLayout, ctaId, srcRowIndices, srcColOffset,
use32BitIndices, /*isGather=*/true);
왜 이게 좋은가
- 대칭 API: scatter/gather가 동일한 패턴으로 제공되어 학습 비용이 낮습니다.
- 코드 재사용: LLVM lowering에서
emitTDMGatherScatter로 scatter/gather 공통 로직을 공유합니다. - Flash Attention 최적화: KV cache에서 비연속 토큰을 효율적으로 수집할 수 있습니다.
정리
TDM scatter에 이은 gather 기능 추가로, AMD gfx1250의 TDM 하드웨어를 활용한 비연속 메모리 접근의 양방향 지원이 완성되었습니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Triton AMD GPU 백엔드: v_perm 명령어를 활용한 레이아웃 변환 최적화
- 현재글 : [triton] AMD gfx1250 Gluon에 Tensor Async Gather(TDM) 지원 추가
- 다음글 [triton] AMD MoveUpPrologueLoads로 ReorderInstructions 패스 완전 대체
댓글