본문으로 건너뛰기

[triton] Warp Specialization: 데이터 플로우 그래프 기반의 개선된 파티션 스케줄링 패스

PR 링크: triton-lang/triton#7312 상태: Merged | 변경: +2334 / -938

들어가며

Warp Specialization에서 파티션 스케줄링은 루프 내 연산을 producer/consumer 파티션에 어떻게 배치할지 결정하는 핵심 패스입니다. 기존 구현은 특정 패턴에 최적화되어 범용성이 부족했습니다. 이번 PR은 데이터 플로우 그래프(DFG)를 구축하고, heuristic 기반 점진적 병합으로 파티션을 결정하는 새로운 패스를 제공합니다.

핵심 코드 분석

그래프 노드 구조

// PartitionSchedulingUtility.h
enum Flags : uint8_t {
  NONE = 0, MANUAL = 1 << 0, LOAD = 1 << 1, STORE = 1 << 2,
  MMA = 1 << 3, TMEM = 1 << 4, SFU = 1 << 5, VIEW = 1 << 6,
};

class Partition {
public:
  void add(Node *node);
  size_t getStage() const {
    if (flags & Flags::MMA) return 1;  // MMA는 stage 1
    return 0;                           // 나머지는 stage 0
  }
  static void merge(Partition *lhs, Partition *rhs);
};

class Node {
public:
  explicit Node(Operation *op) : op(op), cost(computeCost(op)) {}
  // 입력/출력 포트와 엣지를 통한 데이터 의존성 표현
  static void addEdge(OutputPort from, InputPort to);
};

파티션 병합 로직

그래프가 구축되면, 비용 기반 heuristic으로 파티션을 점진적으로 병합합니다. LOAD 플래그가 있는 노드는 producer 파티션에, MMA 플래그가 있는 노드는 consumer 파티션에 배치됩니다. 파티션 간 엣지를 통해 데이터 전달 비용을 계산하고, 비용이 낮은 파티션 쌍을 우선 병합합니다.

왜 이게 좋은가

  1. 범용 접근: DFG 기반 분석으로 다양한 커널 패턴에 대응할 수 있습니다.
  2. Drop-in 대체: 기존 패스와 동일한 인터페이스를 유지하면서 내부 알고리즘만 교체했습니다.
  3. 성능 검증: B200에서 09-persistent-matmul.py06-fused-attention.py에서 성능 회귀가 없음을 확인했습니다.

정리

2334줄 추가의 대규모 재작성이지만, 컴파일러 패스를 더 범용적이고 확장 가능한 구조로 개선한 좋은 사례입니다. 데이터 플로우 그래프 분석은 파티셔닝 문제의 자연스러운 모델링 방법입니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, PR의 실제 diff를 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글