본문으로 건너뛰기

[triton] AMD gfx1250 Gluon에 Tensor Async Scatter 지원 추가

PR 링크: triton-lang/triton#9299 상태: Merged | 변경: +766 / -32

들어가며

TDM(Tensor Data Movement) scatter는 shared memory의 데이터를 global memory의 비연속적인 행들에 비동기적으로 쓰는 연산입니다. Flash Attention에서 다양한 시퀀스 위치에 결과를 흩뿌릴 때 유용합니다. 이 PR은 Gluon 프론트엔드에서 이 기능을 사용할 수 있게 합니다.

핵심 코드 분석

Python API

@builtin
def async_scatter(desc: tensor_descriptor, dst_row_indices: ttgl.tensor,
                  dst_col_offset, src: shared_memory_descriptor,
                  mbarrier=None, _semantic=None) -> None:
    """Scatter data from shared memory to non-contiguous rows.
    
    dst_row_indices의 dtype에 따라:
    - int16: 한 번에 최대 16행
    - int32: 한 번에 최대 8행
    """
    ndim = len(desc.block_shape)
    assert ndim == 2, f"TDM scatter only supports 2D tensors"

MLIR Op 정의

def AsyncTDMScatterOp : TT_AMDGPU_Op<"async_tdm_scatter"> {
  let arguments = (ins
    Arg<TT_TensorDescType, "", [MemWrite<GlobalMemory>]>:$desc,
    TensorOf<[I16, I32]>:$dst_row_indices,
    I32:$dst_col_offset,
    Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src,
    Optional<TTG_MemDescType>:$barrier
  );
}

왜 이게 좋은가

  1. 비연속 쓰기: 인덱스 기반으로 임의의 행에 데이터를 쓸 수 있어 scatter 패턴에 효율적입니다.
  2. 비동기 실행: TDM 엔진이 GPU 코어와 병렬로 데이터를 전송합니다.
  3. 완전한 스택: Python API -> MLIR Op -> LLVM lowering까지 전체 스택을 구현했습니다.

정리

AMD gfx1250의 TDM scatter 하드웨어 기능을 Gluon에서 활용할 수 있게 한 기능 추가입니다. Flash Attention 등에서 비연속 메모리 쓰기 패턴의 성능을 하드웨어 가속으로 극대화할 수 있습니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글