본문으로 건너뛰기

[Triton] AMD gfx1250 Tensor Descriptor 기반 GEMM 테스트 추가

PR 링크: triton-lang/triton#9779 상태: Merged | 변경: +528 / -0

들어가며

AMD GFX1250 아키텍처는 Tensor Descriptor Mode(TDM)를 지원한다. TDM은 텐서의 메모리 레이아웃 정보를 descriptor로 표현하여 하드웨어가 직접 텐서 로드/스토어를 처리하도록 하는 기능이다. 이 PR은 GFX1250에서 TDM을 사용한 GEMM(General Matrix Multiply) 커널의 종합 테스트를 추가한다.

핵심 코드 분석

Device TDM GEMM 커널

커널 내부에서 직접 tensor descriptor를 생성하는 방식이다.

@triton.jit
def gemm_device_tdm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                           BLOCK_K: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    a_desc = tl.make_tensor_descriptor(
        base=a_ptr + pid_m * BLOCK_M * K,
        shape=(M, K),
        strides=(K, 1),
        block_shape=(BLOCK_M, BLOCK_K),
    )
    # ...
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = a_desc.load([0, k])
        b = b_desc.load([k, 0])
        accumulator = tl.dot(a, b, acc=accumulator)

    c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], accumulator)

Host TDM GEMM 커널

Host에서 TensorDescriptor 객체를 미리 생성하여 커널 인자로 전달하는 방식이다.

def _run_kernel(x, y, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_tdm='device'):
    z = torch.empty(M, N, dtype=torch.float32).cuda()
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
    if use_tdm == 'host':
        a_desc = TensorDescriptor.from_tensor(x, (BLOCK_M, BLOCK_K))
        b_desc = TensorDescriptor.from_tensor(y, (BLOCK_K, BLOCK_N))
        c_desc = TensorDescriptor.from_tensor(z, (BLOCK_M, BLOCK_N))
        gemm_host_tdm_kernel[grid](a_desc, b_desc, c_desc, M, N, K,
                                   BLOCK_M, BLOCK_N, BLOCK_K)
    else:
        gemm_device_tdm_kernel[grid](x, y, z, M, N, K,
                                     BLOCK_M, BLOCK_N, BLOCK_K)
    return z

Scale Pre-shuffle 최적화

MXFP GEMM에서는 scale 텐서를 128-element 그룹으로 pre-shuffle하여 메모리 접근 효율을 높인다.

def pack_scale(x: torch.Tensor, preshuffle_factor: int = 128) -> torch.Tensor:
    NON_K, K_SCALE = x.shape
    SCALE_KWIDTH = 4 if K_SCALE >= 4 else K_SCALE
    num_chunk_m = NON_K // preshuffle_factor
    num_chunk_k = K_SCALE // SCALE_KWIDTH
    # [NON_K, K_SCALE] -> [num_chunk_m, 4, 32, num_chunk_k, SCALE_KWIDTH]
    x = x.view(num_chunk_m, 4, preshuffle_factor // 4, num_chunk_k, SCALE_KWIDTH)
    x = x.permute(0, 3, 2, 1, 4).contiguous()

왜 이게 좋은가

  1. 포괄적 테스트 커버리지: FP16 GEMM, MXFP GEMM, MXFP Fused Attention의 3가지 핵심 워크로드를 커버한다.
  2. Device/Host TDM 모드 모두 검증: descriptor 생성 위치에 따른 두 가지 경로를 모두 테스트한다.
  3. 다양한 블록 크기: (32x32x64), (128x128x128) 등 여러 타일 크기 조합을 parametrize로 검증한다.
  4. Pre-shuffled scale layout: 실제 프로덕션에서 사용되는 최적화된 scale 레이아웃을 테스트에 포함한다.

정리

이 PR은 AMD GFX1250의 Tensor Descriptor Mode에 대한 종합 테스트 스위트를 추가했다. FP16 GEMM, MXFP GEMM, Fused Attention 커널을 device/host TDM 모드로 검증하며, scale pre-shuffle 최적화까지 포함한다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글