본문으로 건너뛰기

[triton] AMD TDM의 Partition-Aware 분할 및 다중 Intrinsic 지원

PR 링크: triton-lang/triton#9844 상태: Merged | 변경: +461 / -61

들어가며

TDM(Triton Data Mover)은 AMD GPU에서 global memory와 shared memory 간의 비동기 데이터 전송을 담당합니다. Partitioned shared encoding은 하나의 shared memory 블록을 여러 파티션으로 나누어 파이프라인 버퍼링을 구현하는데, 기존의 warp 배분 알고리즘은 파티션 경계를 고려하지 않아 하나의 warp가 여러 파티션에 걸쳐 데이터를 쓰는 문제가 있었습니다.

핵심 코드 분석

Before (파티션 무시):

SmallVector<unsigned> getWarpDistribution(ArrayRef<int64_t> blockShape,
                                          int numWarps) {
  // 단순히 blockShape 기반으로 warp 배분
  tdmGetWarpDistribution(blockShape.data(), numDims, numWarps, warps.data());
}

After (파티션 경계 인식):

std::pair<SmallVector<unsigned>, unsigned> distributeTDMWarpsAlignToPartition(
    ArrayRef<int64_t> blockShape, int numWarps,
    PartitionedSharedEncodingAttr partitionedEnc) {
  unsigned numDims = blockShape.size();
  unsigned partitionDim = partitionedEnc.getPartitionDim();
  // 1. numLogicalPieces개의 warp를 partitionDim에 우선 배정
  // 2. 남은 warp를 non-partition, non-inner 차원에 배분
  // 3. innermost 차원은 절대 분할하지 않음 (패딩 정합성)
  // Returns: (warpsPerCTA, numTDMInstructions)
}

핵심은 innermost 차원을 분할하지 않는 이유입니다: TDM은 데이터를 linear element stream으로 기록하며 padInterval마다 패딩을 삽입합니다. warp 단위로 inner 차원을 분할하면 per-warp extent가 row width와 달라져 패딩 경계가 어긋납니다.

왜 이게 좋은가

파티션 경계를 인식하는 warp 배분은 각 warp가 단일 파티션 내에서만 동작하도록 보장하여, 파티션 간 데이터 오염을 방지합니다. numWarps < numLogicalPieces인 경우 여러 TDM 명령어를 생성하고, 이 수를 wait count 계산에 반영하여 동기화 정합성도 유지합니다. 이전에는 partitioned layout에서 TDM을 사용할 수 없거나 잘못된 결과를 낼 수 있었는데, 이 PR로 안전하게 지원됩니다.

정리

TDM의 warp 배분을 partitioned shared encoding의 파티션 경계에 맞추고, 다중 TDM 명령어 생성 시 정확한 wait count를 계산하도록 개선했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글