[Triton] GFX1250 matmul 커널에 Scale Swizzling 통합
들어가며
AMD gfx1250 플랫폼의 MXFP matmul에서 scale 데이터의 메모리 레이아웃 최적화(swizzling)를 통합하는 PR이다. Scale swizzling은 LDS 뱅크 충돌을 피하기 위해 scale 값의 물리적 배치를 재배열하는 기법이다.
핵심 코드 분석
Before
gfx1250에서는 scale swizzling이 적용되지 않았고, 범용 경로만 사용했다.
After
# _matmul.py에 GFX1250 scale swizzling 경로 추가
elif SWIZZLE_MX_SCALE == "GFX1250_SCALE":
tl.static_assert(stride_w_mx_k is not None)
tl.static_assert(stride_w_mx_n is not None)
NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 128
PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
stride_scale_k = stride_w_mx_k
unswizzle 함수의 permute도 수정되었다:
# Before
x = x.permute(0, 1, 4, 3, 2, 5) # 잘못된 permute
# After
x = x.permute(0, 3, 2, 1, 4) # 올바른 permute
왜 이게 좋은가
- 하드웨어 최적화: gfx1250의 LDS 구조에 맞는 scale 레이아웃으로 뱅크 충돌을 줄인다.
- gfx1250 matmul 테스트 확장: block size 제약(
block_m: 128, block_n: 128, block_k: 128)과 함께 해당 아키텍처 전용 테스트 경로가 추가되었다. - permute 버그 수정: 기존 6차원 permute에서 5차원으로 수정하여 정확한 unswizzling을 보장한다.
정리
+28/-11의 간결한 변경이지만, 새 아키텍처에서의 scale 레이아웃 최적화라는 실질적 성능 영향이 있는 PR이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글