[Triton] GFX1250용 MXGEMM Gluon 커널 업데이트
들어가며
AMD gfx1250의 MXGEMM(Microscaled GEMM) Gluon 커널에서 세 가지 개선이 이루어졌다: hip.init(0) 워크어라운드 제거, TDM(Tensor Data Mover)으로 전체 텐서를 한 번에 로드하여 sgpr 압력 감소, 그리고 padded layout 수정이다.
핵심 코드 분석
Before: 슬라이스별 로드와 subtile 관리
# 슬라이스별 로드
BLOCK_K_PACKED_A = BLOCK_K // self.DIV_FACTOR_A // NUM_SUBTILES_K
# ...
self.shared_layout_a = gl.constexpr(
gl.PaddedSharedLayout.with_identity_for(
[[BLOCK_K_PACKED_A, 16]],
[BLOCK_M // NUM_SUBTILES_M, BLOCK_K_PACKED_A], [1, 0]))
각 subtile을 개별적으로 로드하고, shared memory layout도 subtile 크기에 맞췄다.
After: 전체 텐서를 한 번에 로드
BLOCK_K_PACKED_A = BLOCK_K // self.DIV_FACTOR_A
PAD_INTERVAL_A = 256 if BLOCK_K_PACKED_A <= 256 else BLOCK_K_PACKED_A
self.shared_layout_a = gl.constexpr(
gl.PaddedSharedLayout.with_identity_for(
[[PAD_INTERVAL_A, 16]], [BLOCK_M, BLOCK_K_PACKED_A], [1, 0]))
subtile 분할(// NUM_SUBTILES_K)이 제거되고, 전체 블록을 한 번에 로드한다. 이는 LLVM 백엔드의 LDS 인덱싱 버그가 수정된 후 가능해졌다.
파이프라인 스케줄도 변경되었다:
# Before: lds와 tdm+wmma를 분리
with gl.amd.warp_pipeline_stage("lds", priority=1): ...
with gl.amd.warp_pipeline_stage("tdm+wmma", priority=0): ...
# After: tdm+lds를 합치고 wmma만 분리
with gl.amd.warp_pipeline_stage("tdm+lds", priority=1): ...
with gl.amd.warp_pipeline_stage("wmma", priority=0): ...
왜 이게 좋은가
- sgpr 압력 감소: 전체 텐서를 한 번에 TDM으로 로드하면 sgpr backpressure가 줄어든다.
- 코드 단순화: subtile 관리 로직이 제거되어 가독성이 크게 향상되었다.
- hip.init 제거: 불필요한
hip.hipInit(0)워크어라운드가 제거되었다.
정리
+239/-338로 코드가 약 100줄 줄면서도 성능이 개선된 경우다. LLVM 버그 수정 이후 상위 레벨 최적화가 가능해진 좋은 사례다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글