[Triton] WGMMA rs-dot 분할을 2회로 제한 — 1% MoE 성능 향상
PR 링크: triton-lang/triton#9152 상태: Merged | 변경: +21 / -33
들어가며
WGMMA에서 LHS가 레지스터에 있는 경우(rs-dot), K 차원을 여러 조각으로 분할하여 각 WGMMA가 서로 다른 레지스터 집합을 사용하도록 한다. 이를 통해 하나의 WGMMA가 완료되기 전에 다음 WGMMA가 시작될 수 있다(in-register pipelining). 기존에는 K/instrK개로 최대 분할했지만, 실험 결과 2분할이 최적이었다.
핵심 코드 분석
Before
std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
// K/16 개의 wgmma(tensor, shmem) Mx16, 16xN → MxN으로 분할
auto newK = cast<ttg::NvidiaMmaEncodingAttr>(...)
.getInstrShape()[2];
auto numSplits = origK / newK;
if (numSplits <= 1) return {dotOp};
After
std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
// wgmma(tensor, shmem, acc)를 2분할:
// wgmma(tensor[:, :K//2], shmem[:K//2, :], acc)
// wgmma(tensor[:, K//2:], shmem[K//2:, :], acc)
auto instrK = cast<ttg::NvidiaMmaEncodingAttr>(...)
.getInstrShape()[2];
if (origK <= instrK) return {dotOp};
constexpr int numSplits = 2;
uint32_t newK = origK / numSplits;
왜 이게 좋은가
- 실측 개선: bf16 x mxfp4 MoE에서 ~1% 성능 향상이 관측되었다.
- 레지스터 압력 감소: 2분할은 2세트의 레지스터만 필요하지만, 최대 분할은 수십 세트가 필요하여 레지스터 스필링이 발생할 수 있다.
- 파이프라이닝 균형: 2개의 WGMMA가 번갈아 실행되면 충분한 오버랩이 달성되며, 그 이상은 오버헤드가 이득을 초과한다.
- 코드 단순화: 21줄 삭제로 테스트의 CHECK 패턴도 대폭 간결해졌다.
정리
"이론상 최적"과 "실측 최적"이 다른 전형적인 사례다. K 차원 최대 분할은 이론적으로 최대 병렬성을 제공하지만, 레지스터 압력과 wait 오버헤드를 고려하면 2분할이 실용적 최적점이다.
참고 자료
이 글은 AI 도구의 도움을 받아 작성되었습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] WarpSpecializePartitionsOp에 명시적 캡처 전달 — IR 구조 정합성 개선
- 현재글 : [Triton] WGMMA rs-dot 분할을 2회로 제한 — 1% MoE 성능 향상
- 다음글 [Triton] Proton GlobalScratchAllocOp 폐기 — TritonGPU 공용 op으로 통합
댓글