본문으로 건너뛰기

[triton] rewrite-partition-dependencies를 insert-aref로 통합하여 Warp Specialization 파이프라인 간소화

PR 링크: triton-lang/triton#8619 상태: Merged | 변경: +754 / -1038

들어가며

Triton의 Warp Specialization은 GPU warp들을 서로 다른 역할(producer/consumer)로 분화시켜 파이프라인 병렬성을 극대화하는 기법입니다. 기존에는 파티션 간 SSA 의존성을 shared memory로 재작성하는 별도의 RewritePartitionDependencies pass가 있었는데, 이 PR은 그 기능을 InsertAref pass에 통합합니다.

핵심 코드 분석

Before - 별도 pass 존재

// AutomaticWarpSpecialization.cpp
pm.addPass(createTritonGPUPartitionScheduling());
pm.addPass(createNVWSInsertAref());
pm.addPass(createNVWSInsertTmemAref());
pm.addPass(createTritonGPURewritePartitionDependencies()); // 별도 pass

After - pass 제거 및 통합

// AutomaticWarpSpecialization.cpp
pm.addPass(createTritonGPUPartitionScheduling());
pm.addPass(createNVWSInsertAref());
pm.addPass(createNVWSInsertTmemAref());
// RewritePartitionDependencies 제거 - InsertAref에 통합됨

새로운 getPartitionIds 헬퍼

SetVector<int> getPartitionIds(OpOperand *use) {
  auto owner = use->getOwner();
  if (isa<scf::YieldOp>(owner)) {
    return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()];
  } else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) {
    int idx = use->getOperandNumber() - forOp.getNumControlOperands();
    return idx >= 0 ? getPartitionOutputs(owner)[idx] : *getPartitionIds(forOp);
  } else {
    return *getPartitionIds(owner);
  }
}

왜 이게 좋은가

  1. 파이프라인 간소화: 1038줄 삭제 / 754줄 추가로 순수 284줄 감소하면서 pass 수가 줄었습니다.
  2. 정보 지역성: aref 삽입 시점에서 이미 파티션 정보가 있으므로, 같은 시점에 dependency 재작성을 하면 정보 전달이 자연스럽습니다.
  3. 유지보수 용이: RewritePartitionDependencies.cpp 파일 전체가 삭제되어 코드베이스가 간결해졌습니다.

정리

Warp Specialization의 컴파일 파이프라인에서 두 개의 pass를 하나로 통합한 리팩터링입니다. 기능적으로는 동일하지만 코드 구조가 간결해지고 pass 간 정보 전달이 자연스러워졌습니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글