[Triton] AMD scf.if else 분기 누락 버그 수정 — deduceMinCountBetweeOps
PR 링크: triton-lang/triton#9034 상태: Merged | 변경: +31 / -0
들어가며
Triton 컴파일러의 AMD 백엔드에는 deduceMinCountBetweeOps라는 함수가 있다. 이 함수는 두 연산 사이에 존재하는 비동기 명령어의 최소 개수를 계산하여 async_wait의 대기 카운트를 결정한다. 문제는 scf.if에 else 영역이 없는 경우를 처리하지 않아, else가 없으면 assert에서 크래시가 발생하는 것이었다.
핵심 코드 분석
Before
int count = 0;
for (auto op = beginOp; op != endOp; op = op->getNextNode()) {
if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op)) {
assert(!ifOp.getThenRegion().empty() && !ifOp.getElseRegion().empty());
auto minThen =
deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc);
else 영역이 비어있으면 assert가 실패한다.
After
int count = 0;
for (auto op = beginOp; op != endOp; op = op->getNextNode()) {
if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op)) {
if (ifOp.getElseRegion().empty())
continue;
assert(!ifOp.getThenRegion().empty() && !ifOp.getElseRegion().empty());
auto minThen =
deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc);
else가 없으면 해당 scf.if를 건너뛴다. else 경로에서의 비동기 명령어 수가 0이므로, 양쪽 경로의 최솟값은 0이 되기 때문이다.
왜 이게 좋은가
- 정확성 복원: else가 없는
scf.if는 else 경로에서 비동기 명령어가 0개다. 최솟값을 취하면 0이므로,continue로 건너뛰는 것이 수학적으로 올바르다. - 크래시 방지: 실제 커널에서
scf.if에 else가 없는 패턴은 흔하다. 이 수정 없이는 해당 패턴의 커널이 컴파일 시 크래시한다. - 테스트 포함: else 없는
scf.if가 def chain에 있을 때num_inst = 0이 올바르게 추론되는지 검증하는 MLIR 테스트가 추가되었다.
정리
단 3줄의 가드 조건 추가로, else 분기가 없는 제어 흐름에서 발생하는 크래시를 해결했다. 간결하지만 논리적으로 정확한 수정이다.
참고 자료
이 글은 AI 도구의 도움을 받아 작성되었습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Triton GFX1250 MXFP GEMM 커널의 4-Warp 스케줄링 최적화 분석
- 현재글 : [Triton] AMD scf.if else 분기 누락 버그 수정 — deduceMinCountBetweeOps
- 다음글 [Triton] Frontend에서 scaled batched matrix multiply 지원
댓글