[triton] AMD gfx1250 Gluon에 Tensor Async Scatter 지원 추가
PR 링크: triton-lang/triton#9299 상태: Merged | 변경: +766 / -32
들어가며
TDM(Tensor Data Movement) scatter는 shared memory의 데이터를 global memory의 비연속적인 행들에 비동기적으로 쓰는 연산입니다. Flash Attention에서 다양한 시퀀스 위치에 결과를 흩뿌릴 때 유용합니다. 이 PR은 Gluon 프론트엔드에서 이 기능을 사용할 수 있게 합니다.
핵심 코드 분석
Python API
@builtin
def async_scatter(desc: tensor_descriptor, dst_row_indices: ttgl.tensor,
dst_col_offset, src: shared_memory_descriptor,
mbarrier=None, _semantic=None) -> None:
"""Scatter data from shared memory to non-contiguous rows.
dst_row_indices의 dtype에 따라:
- int16: 한 번에 최대 16행
- int32: 한 번에 최대 8행
"""
ndim = len(desc.block_shape)
assert ndim == 2, f"TDM scatter only supports 2D tensors"
MLIR Op 정의
def AsyncTDMScatterOp : TT_AMDGPU_Op<"async_tdm_scatter"> {
let arguments = (ins
Arg<TT_TensorDescType, "", [MemWrite<GlobalMemory>]>:$desc,
TensorOf<[I16, I32]>:$dst_row_indices,
I32:$dst_col_offset,
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src,
Optional<TTG_MemDescType>:$barrier
);
}
왜 이게 좋은가
- 비연속 쓰기: 인덱스 기반으로 임의의 행에 데이터를 쓸 수 있어 scatter 패턴에 효율적입니다.
- 비동기 실행: TDM 엔진이 GPU 코어와 병렬로 데이터를 전송합니다.
- 완전한 스택: Python API -> MLIR Op -> LLVM lowering까지 전체 스택을 구현했습니다.
정리
AMD gfx1250의 TDM scatter 하드웨어 기능을 Gluon에서 활용할 수 있게 한 기능 추가입니다. Flash Attention 등에서 비연속 메모리 쓰기 패턴의 성능을 하드웨어 가속으로 극대화할 수 있습니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray Data] 논리적 최적화 규칙에서 in-place 변형을 제거하여 불변성 준비
- 현재글 : [triton] AMD gfx1250 Gluon에 Tensor Async Scatter 지원 추가
- 다음글 [Open WebUI] 검색 쿼리 디바운스 적용으로 불필요한 DB 요청 감소
댓글