본문으로 건너뛰기

[triton] MemDescSubslice에서 Non-CTA 차원 슬라이싱 지원

PR 링크: triton-lang/triton#9507 상태: Merged | 변경: +59 / -12

들어가며

Triton의 memdesc_subslice 연산은 shared memory descriptor를 더 작은 부분으로 나누는 역할을 합니다. 기존에는 non-trivial block 차원(multi-CTA)이 있으면 일률적으로 에러를 반환했지만, 실제로는 CTA가 데이터를 broadcast하지 않는 한 안전하게 슬라이싱할 수 있습니다. 이 PR은 이 제한을 완화하여 multi-CTA 환경에서의 메모리 슬라이싱을 가능하게 합니다.

핵심 코드 분석

Before:

// NYI: We don't support non-trivial block dimension for now.
auto kBlock = mlir::StringAttr::get(getContext(), "block");
if (ll.getInDimSize(kBlock) != 1) {
    return emitError("non-trivial block dimension not supported");
}

After:

// If any block basis is fully broadcasted, multiple CTAs can alias the same
// output tile region. Subslice on such layouts is unsupported.
auto kBlock = mlir::StringAttr::get(ctx, "block");
if (ll.getFreeVariableMasks()[kBlock] != 0) {
    return emitError("We don't support splitting with broadcasted CTA outputs");
}

// ... offset 검증에서도 block 차원 체크 추가
auto offsetAndBlock = llInv.apply(namedOffsets);
auto offset = offsetAndBlock[0];
auto block = offsetAndBlock[1];
if (!llvm::isPowerOf2_32(offset.second) && offset.second != 0) {
    return emitError("We don't support splitting along the swizzling pattern");
}
if (block.second != 0) {
    return emitError("We don't support splitting along CTA dimensions");
}

왜 이게 좋은가

기존의 보수적인 검증(blockDimSize != 1이면 거부)은 안전하지만 지나치게 제한적이었습니다. 새 로직은 LinearLayout의 free variable mask를 활용하여 **실제로 위험한 경우(broadcast된 CTA)**만 거부하고, CTA가 독립적인 영역을 담당하는 경우는 허용합니다. 이를 통해 multi-CTA GEMM 등에서 shared memory 타일을 서브타일로 분할하는 최적화가 가능해집니다.

정리

memdesc_subslice의 block 차원 검증을 "크기 1이 아니면 거부"에서 "broadcast된 CTA가 있으면 거부"로 정교화하고, 슬라이싱 시 CTA 차원을 따라 분할하는 경우도 별도로 검증하도록 개선했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글