[triton] AMD BlockPingpong 패스의 non-MFMA dot 크래시 수정
PR 링크: triton-lang/triton#9618 상태: Merged | 변경: +79 / -1
들어가며
AMD GPU의 BlockPingpong 최적화는 MFMA(Matrix Fused Multiply-Add) 명령어 기반의 dot 연산에서 스케줄링을 개선하는 패스입니다. 그러나 FMA 기반 dot 연산(MFMA가 아닌 일반 행렬 곱)이 이 패스에 전달되면, MFMA encoding으로의 잘못된 cast로 인해 크래시가 발생했습니다.
핵심 코드 분석
Before:
void Pingponger::getDotPingponged() {
auto encoding = cast<RankedTensorType>(aType).getEncoding();
auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
kWidth = srcEncoding.getKWidth();
auto mfmaEncoding = cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
// cast 실패 시 크래시
SmallVector<int64_t> intShape;
auto mnkDim = mfmaEncoding.getInstrShape();
// ...
}
After:
void Pingponger::getDotPingponged() {
auto encoding = cast<RankedTensorType>(aType).getEncoding();
auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
kWidth = srcEncoding.getKWidth();
auto mfmaEncoding =
dyn_cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
if (!mfmaEncoding) {
LDBG("Encountered non-MFMA layout");
return;
}
SmallVector<int64_t> intShape;
auto mnkDim = mfmaEncoding.getInstrShape();
// ...
}
cast를 dyn_cast로 변경하고 null 체크를 추가하여, MFMA가 아닌 인코딩(예: FMA)일 때 안전하게 조기 반환합니다. 테스트에서는 #fake_mma = #ttg.blocked<...> 인코딩을 사용한 FMA dot으로 이 경로를 검증합니다.
왜 이게 좋은가
MLIR 기반 컴파일러에서 cast vs dyn_cast의 올바른 사용은 매우 중요합니다. cast는 타입이 확실할 때만 사용해야 하며, 그렇지 않으면 즉시 크래시합니다. 이 PR은 방어적 프로그래밍의 좋은 사례로, 현재 BlockPingpong이 MFMA 전용임을 명시하면서도 다른 인코딩에서 안전하게 동작하도록 보장합니다. 테스트 케이스에 CHECK-NOT: rocdl.sched.barrier와 CHECK-NOT: rocdl.s.setprio를 포함하여 최적화가 적용되지 않음을 검증합니다.
정리
cast<AMDMfmaEncodingAttr>을dyn_cast로 변경하여 non-MFMA 크래시 방지- FMA dot 연산에 대한 테스트 케이스 추가
- BlockPingpong이 MFMA 전용임을 코드에 명시적으로 표현
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Profile scratch용 기본 allocator 제공
- 현재글 : [triton] AMD BlockPingpong 패스의 non-MFMA dot 크래시 수정
- 다음글 [sglang] MoE 모델 추론 최적화: Triton 커널 퓨전을 통한 TTFT 28% 개선
댓글