[triton] Triton GFX1250 MXFP GEMM 커널의 4-Warp 스케줄링 최적화 분석
PR 링크: triton-lang/triton#9031 상태: Merged | 변경: +839 / -253
들어가며
최근 Triton 레포지토리의 AMD GFX1250 타겟 MXFP(Microscaling Formats) GEMM 커널에 중요한 리팩토링과 성능 개선이 이루어졌습니다. 이번 업데이트의 핵심은 기존 커널 구조를 유연하게 개선하여 '4-Warp 스케줄링'을 지원하고, 스케일 데이터 처리에 비동기 복사(Async Copy)를 도입하여 SGPR(Scalar General Purpose Register) 압박을 완화한 것입니다. 본 글에서는 이 변경사항이 어떻게 커널의 효율성을 높였는지 분석합니다.
코드 분석
1. MXFPGEMMConfig의 유연한 구조화
기존 커널은 고정된 스케줄링에 의존했으나, MXFPGEMMConfig 클래스를 도입하여 다양한 스케줄링 전략을 설정값으로 제어할 수 있게 되었습니다.
# Before: 고정된 레이아웃 및 설정
# After: MXFPGEMMConfig를 통한 유연한 설정
@aggregate
class MXFPGEMMConfig:
BLOCK_M: gl.constexpr
BLOCK_N: gl.constexpr
# ... (중략)
ASYNC_COPY_SCALE: gl.constexpr
NUM_SUBTILES: gl.constexpr
이 구조를 통해 NUM_SUBTILES를 조절하여 A 행렬은 K 차원을 따라, B 행렬은 N과 K 차원을 따라 슬라이싱하는 4-Warp 스케줄링을 구현했습니다.
2. Async Copy를 통한 스케일 로드 최적화
SGPR 압박을 줄이기 위해 스케일 데이터 로드에 비동기 복사를 선택적으로 사용할 수 있게 되었습니다. ScaleAsyncCopyDescriptor 클래스는 이를 위해 설계되었습니다.
# After: 비동기 복사 로직 추가
@gluon.jit
def issue_async_load(self, idx: int, buffer, pred=True):
NUM_SUBTILES_NONK: gl.constexpr = self.cfg.NUM_SUBTILES[self.op_idx]
if pred:
cp.global_to_shared(
buffer, self.ptr + (idx % NUM_SUBTILES_NONK) * self.step_nonk +
(idx // NUM_SUBTILES_NONK) * self.step_k + self.offs)
cp.commit_group()
이 방식은 메모리 로드와 연산을 겹쳐(Overlap) 지연 시간을 숨기고, 레지스터 사용량을 최적화하여 더 높은 Occupancy를 달성합니다.
왜 이게 좋은가
이번 최적화는 크게 두 가지 측면에서 이점을 제공합니다:
- SGPR 압박 완화: 스케일 데이터를 비동기적으로 로드함으로써, 스칼라 레지스터(SGPR)에 가해지는 부담을 줄였습니다. 이는 더 복잡한 커널 로직을 실행할 수 있는 여유를 제공합니다.
- 유연한 스케줄링: 4-Warp 스케줄링을 통해 하드웨어 유닛(WMMA)의 활용도를 극대화했습니다. 특히
NUM_SUBTILES를 통해 타일링 전략을 세밀하게 조정할 수 있어, 다양한 행렬 크기에서 최적의 성능을 낼 수 있습니다.
일반적인 교훈으로, GPU 커널 최적화 시 레지스터 압박(Register Pressure)을 줄이기 위한 비동기 메모리 연산 활용과 하드웨어 아키텍처에 맞춘 유연한 타일링 전략이 성능 향상의 핵심임을 확인할 수 있습니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] wgmma wait(0)를 accumulator 첫 사용 시점으로 지연하여 MMA-epilogue 오버랩 달성
- 현재글 : [triton] Triton GFX1250 MXFP GEMM 커널의 4-Warp 스케줄링 최적화 분석
- 다음글 [Triton] AMD scf.if else 분기 누락 버그 수정 — deduceMinCountBetweeOps
댓글