본문으로 건너뛰기

[triton] AMD TensorDescType의 Shared Memory 크기 계산 수정

PR 링크: triton-lang/triton#9506 상태: Merged | 변경: +90 / -14

들어가며

WarpSpecialize는 GPU의 warp들을 서로 다른 역할(producer/consumer)로 나누어 파이프라인 병렬성을 극대화하는 기법입니다. 이때 warp 간 데이터 전달을 위해 shared memory에 capture list를 저장하는데, 각 타입의 크기를 정확히 계산해야 합니다. 이 PR은 AMD의 TensorDescType(TDM 디스크립터)이 capture로 전달될 때의 크기 계산을 수정합니다.

핵심 코드 분석

기존에는 getCaptureSizeAlign()이 하나의 함수로 size와 alignment를 동시에 반환했고, getSharedMemorySize가 static 함수로 제한되어 있었습니다.

Before:

static size_t getSharedMemorySize(Type type) { ... }

std::pair<uint64_t, uint64_t> WarpSpecializeOp::getCaptureSizeAlign() {
  uint64_t captureSize = 0;
  for (Type type : getPartitionOp().getOperandTypes()) {
    captureSize += getSharedMemorySize(type);
  }
  return {captureSize, 8};
}

After:

// 이제 public 함수로 변경
size_t getSharedMemorySize(Type type) { ... }

uint64_t WarpSpecializeOp::getCaptureSize() {
  uint64_t captureSize = 0;
  for (Type type : getPartitionOp().getOperandTypes()) {
    captureSize += getSharedMemorySize(type);
  }
  return captureSize;
}

uint64_t WarpSpecializeOp::getCaptureAlign() { return 8; }

핵심은 AMD 백엔드에서 TensorDescType의 dword 수를 정확히 계산하는 유틸리티를 추가한 것입니다:

inline int getTensorDescNumDwords(triton::TensorDescType type) {
  auto shape = type.getBlockType().getShape();
  return (shape.size() > 2) ? (4 + 8 + 4 + 4) : (4 + 8);
  // 2D: 12 dwords = 48 bytes, 3D-5D: 20 dwords = 80 bytes
}

왜 이게 좋은가

TensorDescType의 크기가 차원 수에 따라 다르다는 점(2D=48B, 3D-5D=80B)을 명시적으로 처리함으로써, WarpSpecialize의 shared memory 할당이 정확해집니다. 부정확한 크기 계산은 shared memory 범위를 벗어나는 쓰기를 유발하여 데이터 corruption이나 크래시를 일으킬 수 있습니다. 또한 size와 alignment를 별도 API로 분리한 것은 allocation analysis에서 backend별 scratch size를 독립적으로 오버라이드할 수 있게 하여 확장성을 높였습니다.

정리

AMD TensorDescType의 shared memory 크기를 차원 수에 따라 정확히 계산하는 유틸리티를 추가하고, WarpSpecialize의 capture size/align API를 분리하여 backend별 커스터마이징을 지원하도록 개선했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글