[Triton] AMD gfx1250 Tensor Descriptor 기반 GEMM 테스트 추가
PR 링크: triton-lang/triton#9779 상태: Merged | 변경: +528 / -0
들어가며
AMD GFX1250 아키텍처는 Tensor Descriptor Mode(TDM)를 지원한다. TDM은 텐서의 메모리 레이아웃 정보를 descriptor로 표현하여 하드웨어가 직접 텐서 로드/스토어를 처리하도록 하는 기능이다. 이 PR은 GFX1250에서 TDM을 사용한 GEMM(General Matrix Multiply) 커널의 종합 테스트를 추가한다.
핵심 코드 분석
Device TDM GEMM 커널
커널 내부에서 직접 tensor descriptor를 생성하는 방식이다.
@triton.jit
def gemm_device_tdm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
a_desc = tl.make_tensor_descriptor(
base=a_ptr + pid_m * BLOCK_M * K,
shape=(M, K),
strides=(K, 1),
block_shape=(BLOCK_M, BLOCK_K),
)
# ...
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = a_desc.load([0, k])
b = b_desc.load([k, 0])
accumulator = tl.dot(a, b, acc=accumulator)
c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], accumulator)
Host TDM GEMM 커널
Host에서 TensorDescriptor 객체를 미리 생성하여 커널 인자로 전달하는 방식이다.
def _run_kernel(x, y, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_tdm='device'):
z = torch.empty(M, N, dtype=torch.float32).cuda()
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
if use_tdm == 'host':
a_desc = TensorDescriptor.from_tensor(x, (BLOCK_M, BLOCK_K))
b_desc = TensorDescriptor.from_tensor(y, (BLOCK_K, BLOCK_N))
c_desc = TensorDescriptor.from_tensor(z, (BLOCK_M, BLOCK_N))
gemm_host_tdm_kernel[grid](a_desc, b_desc, c_desc, M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K)
else:
gemm_device_tdm_kernel[grid](x, y, z, M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K)
return z
Scale Pre-shuffle 최적화
MXFP GEMM에서는 scale 텐서를 128-element 그룹으로 pre-shuffle하여 메모리 접근 효율을 높인다.
def pack_scale(x: torch.Tensor, preshuffle_factor: int = 128) -> torch.Tensor:
NON_K, K_SCALE = x.shape
SCALE_KWIDTH = 4 if K_SCALE >= 4 else K_SCALE
num_chunk_m = NON_K // preshuffle_factor
num_chunk_k = K_SCALE // SCALE_KWIDTH
# [NON_K, K_SCALE] -> [num_chunk_m, 4, 32, num_chunk_k, SCALE_KWIDTH]
x = x.view(num_chunk_m, 4, preshuffle_factor // 4, num_chunk_k, SCALE_KWIDTH)
x = x.permute(0, 3, 2, 1, 4).contiguous()
왜 이게 좋은가
- 포괄적 테스트 커버리지: FP16 GEMM, MXFP GEMM, MXFP Fused Attention의 3가지 핵심 워크로드를 커버한다.
- Device/Host TDM 모드 모두 검증: descriptor 생성 위치에 따른 두 가지 경로를 모두 테스트한다.
- 다양한 블록 크기: (32x32x64), (128x128x128) 등 여러 타일 크기 조합을 parametrize로 검증한다.
- Pre-shuffled scale layout: 실제 프로덕션에서 사용되는 최적화된 scale 레이아웃을 테스트에 포함한다.
정리
이 PR은 AMD GFX1250의 Tensor Descriptor Mode에 대한 종합 테스트 스위트를 추가했다. FP16 GEMM, MXFP GEMM, Fused Attention 커널을 device/host TDM 모드로 검증하며, scale pre-shuffle 최적화까지 포함한다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
- [triton] AMD/Gluon: gfx1250에서 async_copy 런타임 테스트 추가 및 UpdateAsyncWaitCnt 활성화
- [triton] Triton AMD 백엔드 최적화: SGPR 활용과 루프 최적화를 통한 GEMM 성능 향상
- [triton] AMD gfx1250에서 Async Copy와 TDM 경로의 Padded Layout 휴리스틱 통합
- [triton] AMD gfx1250 MXFP Flash Attention 예제 커널 업데이트
- [triton] AMD 백엔드에서 Floating-Point Sanitizer(FPSan) 지원 활성화
PR Analysis 의 다른글
- 이전글 [sglang] DeepEP Low Latency FP8 Dispatch 변경 revert
- 현재글 : [Triton] AMD gfx1250 Tensor Descriptor 기반 GEMM 테스트 추가
- 다음글 [sglang] NPU 호환성 수정: empty_cache와 memory_saver 충돌 해결
댓글