본문으로 건너뛰기

[triton] AMD gfx1250 Gluon에 Tensor Async Gather(TDM) 지원 추가

PR 링크: triton-lang/triton#9313 상태: Merged | 변경: +505 / -54

들어가며

TDM gather는 global memory의 비연속적인 행들에서 데이터를 shared memory로 비동기적으로 읽어오는 연산입니다. 앞서 추가된 scatter의 대칭 연산으로, Flash Attention에서 KV cache의 비연속 토큰 위치에서 데이터를 수집할 때 유용합니다.

핵심 코드 분석

Python API

@builtin
def async_gather(desc: tensor_descriptor, src_row_indices: ttgl.tensor,
                 src_col_offset, dst: shared_memory_descriptor,
                 mbarrier=None, _semantic=None) -> None:
    """Gather data from non-contiguous rows in global memory.
    
    src_row_indices의 dtype에 따라:
    - int16: 한 번에 최대 16행
    - int32: 한 번에 최대 8행
    """

MLIR Op 정의

def AsyncTDMGatherOp : TT_AMDGPU_Op<"async_tdm_gather"> {
  let arguments = (ins
    Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
    TensorOf<[I16, I32]>:$src_row_indices,
    I32:$src_col_offset,
    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$dst,
    Optional<TTG_MemDescType>:$barrier
  );
}

LLVM Lowering에서 scatter와 코드 공유

// scatter와 gather가 동일한 emitTDMGatherScatter 함수 사용
mlir::LLVM::AMD::emitTDMGatherScatter(
    rewriter, loc, getTypeConverter(), desc, shapePerCTA, srcPtr, pred,
    elementType, barrierPtr, cgaLayout, ctaId, srcRowIndices, srcColOffset,
    use32BitIndices, /*isGather=*/true);

왜 이게 좋은가

  1. 대칭 API: scatter/gather가 동일한 패턴으로 제공되어 학습 비용이 낮습니다.
  2. 코드 재사용: LLVM lowering에서 emitTDMGatherScatter로 scatter/gather 공통 로직을 공유합니다.
  3. Flash Attention 최적화: KV cache에서 비연속 토큰을 효율적으로 수집할 수 있습니다.

정리

TDM scatter에 이은 gather 기능 추가로, AMD gfx1250의 TDM 하드웨어를 활용한 비연속 메모리 접근의 양방향 지원이 완성되었습니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글