[triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화
PR 링크: triton-lang/triton#8779 상태: Merged | 변경: +76 / -6
들어가며
Triton의 AMD 백엔드에는 포인터를 base + offset 형태의 "fat pointer"로 정규화하는 CanonicalizePointers 패스가 있습니다. 이 패스는 32비트 정수 연산으로 오프셋을 처리하여 64비트 포인터 연산을 줄이는 최적화를 수행합니다. 그러나 arith.select 연산에서 한쪽 인자만 fat pointer로 변환된 경우(예: pointer_range = 32가 한쪽에만 설정된 경우), 기존 코드가 크래시하거나 잘못된 결과를 생성하는 문제가 있었습니다.
핵심 코드 분석
Before:
LogicalResult matchAndRewrite_(arith::SelectOp selectOp, ...) {
if (adaptor.getTrueValue().size() != 2 ||
adaptor.getFalseValue().size() != 2) {
assert(adaptor.getTrueValue().size() == adaptor.getFalseValue().size());
return success(); // 둘 다 변환 안 된 경우만 처리
}
// 둘 다 fat pointer인 경우만 처리
ValueRange fatPtrFalse = adaptor.getFalseValue();
ValueRange fatPtrTrue = adaptor.getTrueValue();
// select(base_t, base_f), select(offset_t, offset_f)
}
After:
LogicalResult matchAndRewrite_(arith::SelectOp selectOp, ...) {
ValueRange fatPtrFalse = adaptor.getFalseValue();
ValueRange fatPtrTrue = adaptor.getTrueValue();
if (fatPtrTrue.size() == 1 && fatPtrFalse.size() == 1)
return success(); // 둘 다 미변환: 패스
if (fatPtrTrue.size() != 2 || fatPtrFalse.size() != 2) {
// 비대칭: 한쪽만 fat pointer인 경우
Value trueOp, falseOp;
if (fatPtrTrue.size() == 2) {
trueOp = tt::AddPtrOp::create(rewriter, loc,
selectOp.getType(), fatPtrTrue[0], fatPtrTrue[1]);
} else {
trueOp = fatPtrTrue[0];
}
if (fatPtrFalse.size() == 2) {
falseOp = tt::AddPtrOp::create(rewriter, loc,
selectOp.getType(), fatPtrFalse[0], fatPtrFalse[1]);
} else {
falseOp = fatPtrFalse[0];
}
auto newSelectOp = arith::SelectOp::create(rewriter, loc,
selectOp.getType(), selectOp.getCondition(), trueOp, falseOp);
rewriter.replaceOp(selectOp, newSelectOp);
return success();
}
// 둘 다 fat pointer: 기존 로직 유지
}
핵심 변경은 fat pointer가 비대칭적으로 적용된 경우의 처리입니다. 한쪽이 (base, offset) 형태이고 다른 쪽은 원래 포인터 그대로인 경우, fat pointer 쪽을 AddPtrOp으로 다시 합쳐서 일반 포인터로 되돌린 후 select를 수행합니다.
왜 이게 좋은가
이 수정은 방어적 프로그래밍의 좋은 사례입니다. 기존 코드는 select의 양쪽 인자가 항상 동일한 형태(둘 다 변환되거나 둘 다 미변환)라고 가정했지만, tt.pointer_range = 32 속성이 한쪽에만 설정된 실제 워크로드에서 이 가정이 깨졌습니다. offset이 같다고 암묵적으로 가정하는 것 역시 정적으로 안전하지 않은 코드였습니다. 이제 모든 조합을 명시적으로 처리하므로 컴파일러 크래시가 방지됩니다.
정리
arith.select에서 비대칭 fat pointer 조합 처리 추가- 한쪽만 fat pointer인 경우
AddPtrOp으로 재합성 후 select 수행 - 크래시 방지 및 정확한 코드 생성 보장
- 테스트 케이스(
_scalar_select) 추가
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] SGLang Ascend NPU에서 Ring-SP를 활용한 성능 최적화 가이드
- 현재글 : [triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화
- 다음글 [sglang] FlashInfer v0.6.7 MXFP8 Gemm 통합: CUTLASS와 TensorRT-LLM 백엔드 분리
댓글