[triton] AMD TensorDescType의 Shared Memory 크기 계산 수정
PR 링크: triton-lang/triton#9506 상태: Merged | 변경: +90 / -14
들어가며
WarpSpecialize는 GPU의 warp들을 서로 다른 역할(producer/consumer)로 나누어 파이프라인 병렬성을 극대화하는 기법입니다. 이때 warp 간 데이터 전달을 위해 shared memory에 capture list를 저장하는데, 각 타입의 크기를 정확히 계산해야 합니다. 이 PR은 AMD의 TensorDescType(TDM 디스크립터)이 capture로 전달될 때의 크기 계산을 수정합니다.
핵심 코드 분석
기존에는 getCaptureSizeAlign()이 하나의 함수로 size와 alignment를 동시에 반환했고, getSharedMemorySize가 static 함수로 제한되어 있었습니다.
Before:
static size_t getSharedMemorySize(Type type) { ... }
std::pair<uint64_t, uint64_t> WarpSpecializeOp::getCaptureSizeAlign() {
uint64_t captureSize = 0;
for (Type type : getPartitionOp().getOperandTypes()) {
captureSize += getSharedMemorySize(type);
}
return {captureSize, 8};
}
After:
// 이제 public 함수로 변경
size_t getSharedMemorySize(Type type) { ... }
uint64_t WarpSpecializeOp::getCaptureSize() {
uint64_t captureSize = 0;
for (Type type : getPartitionOp().getOperandTypes()) {
captureSize += getSharedMemorySize(type);
}
return captureSize;
}
uint64_t WarpSpecializeOp::getCaptureAlign() { return 8; }
핵심은 AMD 백엔드에서 TensorDescType의 dword 수를 정확히 계산하는 유틸리티를 추가한 것입니다:
inline int getTensorDescNumDwords(triton::TensorDescType type) {
auto shape = type.getBlockType().getShape();
return (shape.size() > 2) ? (4 + 8 + 4 + 4) : (4 + 8);
// 2D: 12 dwords = 48 bytes, 3D-5D: 20 dwords = 80 bytes
}
왜 이게 좋은가
TensorDescType의 크기가 차원 수에 따라 다르다는 점(2D=48B, 3D-5D=80B)을 명시적으로 처리함으로써, WarpSpecialize의 shared memory 할당이 정확해집니다. 부정확한 크기 계산은 shared memory 범위를 벗어나는 쓰기를 유발하여 데이터 corruption이나 크래시를 일으킬 수 있습니다. 또한 size와 alignment를 별도 API로 분리한 것은 allocation analysis에서 backend별 scratch size를 독립적으로 오버라이드할 수 있게 하여 확장성을 높였습니다.
정리
AMD TensorDescType의 shared memory 크기를 차원 수에 따라 정확히 계산하는 유틸리티를 추가하고, WarpSpecialize의 capture size/align API를 분리하여 backend별 커스터마이징을 지원하도록 개선했습니다.
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [faster-qwen3-tts] 패키지 리네이밍 및 코드 간소화
- 현재글 : [triton] AMD TensorDescType의 Shared Memory 크기 계산 수정
- 다음글 [faster-qwen3-tts] CustomVoice/VoiceDesign 지원, CLI, PyPI 배포, 스트리밍 UX 개선
댓글