본문으로 건너뛰기

[triton] Triton GFX1250 MXFP GEMM 커널의 4-Warp 스케줄링 최적화 분석

PR 링크: triton-lang/triton#9031 상태: Merged | 변경: +839 / -253

들어가며

최근 Triton 레포지토리의 AMD GFX1250 타겟 MXFP(Microscaling Formats) GEMM 커널에 중요한 리팩토링과 성능 개선이 이루어졌습니다. 이번 업데이트의 핵심은 기존 커널 구조를 유연하게 개선하여 '4-Warp 스케줄링'을 지원하고, 스케일 데이터 처리에 비동기 복사(Async Copy)를 도입하여 SGPR(Scalar General Purpose Register) 압박을 완화한 것입니다. 본 글에서는 이 변경사항이 어떻게 커널의 효율성을 높였는지 분석합니다.

코드 분석

1. MXFPGEMMConfig의 유연한 구조화

기존 커널은 고정된 스케줄링에 의존했으나, MXFPGEMMConfig 클래스를 도입하여 다양한 스케줄링 전략을 설정값으로 제어할 수 있게 되었습니다.

# Before: 고정된 레이아웃 및 설정
# After: MXFPGEMMConfig를 통한 유연한 설정
@aggregate
class MXFPGEMMConfig:
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    # ... (중략)
    ASYNC_COPY_SCALE: gl.constexpr
    NUM_SUBTILES: gl.constexpr

이 구조를 통해 NUM_SUBTILES를 조절하여 A 행렬은 K 차원을 따라, B 행렬은 N과 K 차원을 따라 슬라이싱하는 4-Warp 스케줄링을 구현했습니다.

2. Async Copy를 통한 스케일 로드 최적화

SGPR 압박을 줄이기 위해 스케일 데이터 로드에 비동기 복사를 선택적으로 사용할 수 있게 되었습니다. ScaleAsyncCopyDescriptor 클래스는 이를 위해 설계되었습니다.

# After: 비동기 복사 로직 추가
@gluon.jit
def issue_async_load(self, idx: int, buffer, pred=True):
    NUM_SUBTILES_NONK: gl.constexpr = self.cfg.NUM_SUBTILES[self.op_idx]
    if pred:
        cp.global_to_shared(
            buffer, self.ptr + (idx % NUM_SUBTILES_NONK) * self.step_nonk +
            (idx // NUM_SUBTILES_NONK) * self.step_k + self.offs)
        cp.commit_group()

이 방식은 메모리 로드와 연산을 겹쳐(Overlap) 지연 시간을 숨기고, 레지스터 사용량을 최적화하여 더 높은 Occupancy를 달성합니다.

왜 이게 좋은가

이번 최적화는 크게 두 가지 측면에서 이점을 제공합니다:

  1. SGPR 압박 완화: 스케일 데이터를 비동기적으로 로드함으로써, 스칼라 레지스터(SGPR)에 가해지는 부담을 줄였습니다. 이는 더 복잡한 커널 로직을 실행할 수 있는 여유를 제공합니다.
  2. 유연한 스케줄링: 4-Warp 스케줄링을 통해 하드웨어 유닛(WMMA)의 활용도를 극대화했습니다. 특히 NUM_SUBTILES를 통해 타일링 전략을 세밀하게 조정할 수 있어, 다양한 행렬 크기에서 최적의 성능을 낼 수 있습니다.

일반적인 교훈으로, GPU 커널 최적화 시 레지스터 압박(Register Pressure)을 줄이기 위한 비동기 메모리 연산 활용하드웨어 아키텍처에 맞춘 유연한 타일링 전략이 성능 향상의 핵심임을 확인할 수 있습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글