본문으로 건너뛰기

[Triton] WGMMA rs-dot 분할을 2회로 제한 — 1% MoE 성능 향상

PR 링크: triton-lang/triton#9152 상태: Merged | 변경: +21 / -33

들어가며

WGMMA에서 LHS가 레지스터에 있는 경우(rs-dot), K 차원을 여러 조각으로 분할하여 각 WGMMA가 서로 다른 레지스터 집합을 사용하도록 한다. 이를 통해 하나의 WGMMA가 완료되기 전에 다음 WGMMA가 시작될 수 있다(in-register pipelining). 기존에는 K/instrK개로 최대 분할했지만, 실험 결과 2분할이 최적이었다.

핵심 코드 분석

Before

std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
  // K/16 개의 wgmma(tensor, shmem) Mx16, 16xN → MxN으로 분할
  auto newK = cast<ttg::NvidiaMmaEncodingAttr>(...)
                  .getInstrShape()[2];
  auto numSplits = origK / newK;
  if (numSplits <= 1) return {dotOp};

After

std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
  // wgmma(tensor, shmem, acc)를 2분할:
  //   wgmma(tensor[:, :K//2], shmem[:K//2, :], acc)
  //   wgmma(tensor[:, K//2:], shmem[K//2:, :], acc)
  auto instrK = cast<ttg::NvidiaMmaEncodingAttr>(...)
                    .getInstrShape()[2];
  if (origK <= instrK) return {dotOp};
  constexpr int numSplits = 2;
  uint32_t newK = origK / numSplits;

왜 이게 좋은가

  • 실측 개선: bf16 x mxfp4 MoE에서 ~1% 성능 향상이 관측되었다.
  • 레지스터 압력 감소: 2분할은 2세트의 레지스터만 필요하지만, 최대 분할은 수십 세트가 필요하여 레지스터 스필링이 발생할 수 있다.
  • 파이프라이닝 균형: 2개의 WGMMA가 번갈아 실행되면 충분한 오버랩이 달성되며, 그 이상은 오버헤드가 이득을 초과한다.
  • 코드 단순화: 21줄 삭제로 테스트의 CHECK 패턴도 대폭 간결해졌다.

정리

"이론상 최적"과 "실측 최적"이 다른 전형적인 사례다. K 차원 최대 분할은 이론적으로 최대 병렬성을 제공하지만, 레지스터 압력과 wait 오버헤드를 고려하면 2분할이 실용적 최적점이다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글