[Triton] AxisInfoAnalysis에서 미방문 operand 처리 버그 수정
들어가며
Triton 컴파일러의 AxisInfoAnalysis는 텐서의 contiguity, divisibility, constancy 정보를 추적하는 dataflow analysis다. 이 분석이 정확해야 이후 lowering 단계에서 올바른 vectorization이 이루어진다. 그런데 operand가 아직 방문되지 않은 상태에서 강제로 entry state를 설정하는 기존 로직이 분석 순서에 따라 결과가 달라지는 비결정적 버그를 유발하고 있었다. 이 PR은 미방문 operand를 만나면 해당 operation을 건너뛰고 나중에 다시 방문하도록 수정한다.
핵심 코드 분석
Before
// TODO: For sure not the right way to do this
// but why is scf.if not initialized otherwise?
for (auto op : operands)
if (op->getValue().getRank() == 0)
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
operand가 아직 방문되지 않았을 때(rank == 0) 강제로 setToEntryState를 호출하여 기본값으로 초기화했다. 이 접근법은 scf.ForOp, scf.IfOp 같은 control flow 연산의 결과값도 명시적으로 처리해야 했고, 방문 순서에 따라 분석 결과가 달라지는 문제가 있었다.
After
// If any operands are not yet ready, skip this operation for now.
for (auto op : operands)
if (op->getValue().getRank() == 0)
return success();
미방문 operand가 있으면 해당 operation을 건너뛴다. dataflow framework가 operand가 실제로 방문된 후 다시 이 operation을 처리하므로, control flow 처리를 framework에 위임할 수 있게 되었다.
visitForOpInductionVar에서도 동일한 패턴이 적용되었다:
// Before: 강제 초기화
for (auto op_iter : {lbLattice, stepLattice})
if (op_iter->getValue().getRank() == 0)
setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
// After: 준비 안 됐으면 스킵
if (lbLattice->getValue().getRank() == 0 ||
stepLattice->getValue().getRank() == 0) {
return;
}
또한 RegionBranchOpInterface(ForOp, IfOp, WhileOp)에 대한 "unknown" state 초기화 코드를 제거했다. 이제 이들의 결과는 dataflow framework이 자연스럽게 처리한다.
왜 이게 좋은가
- 결정론적 분석: 방문 순서와 무관하게 동일한 결과를 보장한다.
- 코드 단순화: control flow 연산에 대한 특별 처리 코드가 제거되어 약 19줄이 줄었다.
- 크래시 방지: AsyncCopyGlobalToLocalOp lowering 시 vectorization 검증 실패로 인한 크래시를 근본적으로 해결한다.
정리
dataflow analysis에서 "아직 준비 안 된 값을 추측으로 채우는" 대신 "준비될 때까지 기다리는" 방식이 더 안전하다는 교훈을 보여주는 PR이다. +11/-19의 작은 변경이지만 분석 품질과 안정성에 큰 영향을 준다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글