본문으로 건너뛰기

[Triton] Warp Specialization 중첩 루프 지원

PR 링크: triton-lang/triton#8687 상태: Merged | 변경: +1377 / -131

들어가며

Warp Specialization(WS)은 NVIDIA Blackwell GPU에서 서로 다른 warp 그룹이 load와 MMA를 동시에 수행하는 기법이다. 기존에는 단일 루프만 지원했지만, persistent attention 같은 커널은 외부 루프(타일 반복)와 내부 루프(K 차원 반복)의 중첩 구조를 가진다. 이 PR은 partition-schedule 패스를 재귀적으로 확장하여 중첩 루프의 WS를 E2E로 지원한다.

핵심 코드 분석

Before (PartitionSet::fromLoop)

// 루프 body의 직접 연산만 파티셔닝
for (Operation &op : loop.getBody()->without_terminator()) {
  auto attrs = getPartitionIds(&op);
  for (auto idx : attrs) {
    result.partitions[idx]->addOp(&op);
  }
}

After (재귀적 파티셔닝)

// 중첩 루프 내부의 연산도 재귀적으로 파티셔닝
SmallVector<Operation *> annotatedOps;
loop->walk([&](Operation *op) {
  if (hasPartition(op)) {
    annotatedOps.push_back(op);
  }
});

for (auto op : annotatedOps) {
  auto attrs = getPartitionIds(op);
  for (auto idx : attrs) {
    result.partitions[idx]->addOp(op);
  }
}

tmem_alloc Hoisting

내부 루프의 accumulator 초기화를 외부 루프 밖으로 끌어올려, 기본 파티션과 MMA 파티션 간 불필요한 동기화를 방지한다. tl.assume으로 내부 루프가 최소 1회 실행됨이 보장되는 경우를 감지하여 hoisting 안전성을 검증한다.

왜 이게 좋은가

  • E2E 중첩 루프: persistent attention의 외부(타일)/내부(K) 루프 구조가 WS로 완전히 지원된다.
  • 불필요한 동기화 제거: tmem_alloc을 최상위로 hoisting하면, 매 외부 루프 반복마다 accumulator를 0으로 초기화하는 동기화 비용이 사라진다.
  • TMA descriptor 멀티버퍼링: 중첩 루프에서 SWP에 의존하지 않고 독립적으로 TMA descriptor를 멀티버퍼링한다.

정리

중첩 루프의 WS 지원은 persistent kernel 패턴에서 필수적이다. 이 PR은 파티셔닝의 재귀적 확장, tmem_alloc hoisting, TMA descriptor 멀티버퍼링이라는 세 가지 핵심 변경을 통해 이를 달성한다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글