[triton] ConSan Multi-CTA 지원 추가
PR 링크: triton-lang/triton#9764 상태: Merged | 변경: +1452 / -935
들어가며
Triton의 Concurrency Sanitizer(ConSan)는 GPU 커널의 동시성 버그를 런타임에 감지하는 도구입니다. 기존에는 단일 CTA 환경만 지원했는데, 이 PR은 multi-CTA 클러스터(여러 CTA가 하나의 클러스터로 묶여 협력하는 구조)에서도 ConSan이 정상 동작하도록 확장합니다. 핵심은 클러스터 전체가 공유하는 scratch memory 영역을 올바르게 관리하는 것입니다.
핵심 코드 분석
1. shared_cluster_state 속성 도입
Before:
let arguments = (
ins
I32Attr:$nbytes,
I32Attr:$alignment,
OptionalAttr<UnitAttr>:$third_party_allocation
);
After:
let arguments = (
ins
I32Attr:$nbytes,
I32Attr:$alignment,
OptionalAttr<UnitAttr>:$third_party_allocation,
OptionalAttr<UnitAttr>:$shared_cluster_state
);
GlobalScratchAllocOp에 shared_cluster_state 속성을 추가하여, 해당 allocation이 클러스터 전체에서 공유되는 상태인지 명시합니다.
2. Profile scratch pointer의 CTA 오프셋 제어
Before:
if (numCTAs > 1) {
linearId = b.mul(linearId, b.i32_val(numCTAs));
linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc));
}
After:
if (numCTAs > 1) {
linearId = b.mul(linearId, b.i32_val(numCTAs));
if (currentCTA)
linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc));
}
currentCTA 플래그로 포인터를 현재 CTA 기준으로 오프셋할지, 클러스터 전체 텐서의 시작점을 유지할지 제어합니다. shared_cluster_state로 표시된 allocation은 currentCTA=false로 설정되어 모든 CTA가 같은 메모리 영역을 봅니다.
3. 클러스터 CTA ID 연산 추가
def TTI_ExperimentalClusterCTAIdOp
: TTI_Op<"experimental_cluster_cta_id", [Pure]> {
let summary = "Get the CTA id within the current cluster";
let description = [{
Return the cluster-local CTA id used to index ConSan's multi-CTA scratch
slabs. For single-CTA kernels this is always zero.
}];
let results = (outs I32:$result);
}
ConSan이 multi-CTA scratch slab을 인덱싱할 때 사용하는 새로운 IR 연산입니다.
왜 이게 좋은가
- 클러스터 동시성 버그 감지: Multi-CTA 환경에서 발생하는 공유 메모리 경합 문제를 런타임에 감지할 수 있게 되었습니다.
- 기존 코드 호환성:
shared_cluster_state는OptionalAttr이므로 기존 단일 CTA 코드에 영향 없습니다. - 유연한 메모리 관리:
currentCTA파라미터로 per-CTA와 cluster-wide allocation을 하나의 코드 경로에서 처리합니다.
정리
이 PR은 ConSan의 동작 범위를 단일 CTA에서 multi-CTA 클러스터로 확장합니다. shared_cluster_state 속성과 currentCTA 플래그를 통해 scratch memory의 CTA 오프셋 여부를 제어하고, 새로운 ExperimentalClusterCTAIdOp으로 클러스터 내 CTA ID를 얻어 올바른 slab 인덱싱을 수행합니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [axolotl] Axolotl: Triton 커널을 활용한 Entropy 및 Selective Log Softmax 최적화
- 현재글 : [triton] ConSan Multi-CTA 지원 추가
- 다음글 [triton] getTranspositionSelectors 알고리즘 단순화 및 복원
댓글