본문으로 건너뛰기

[Triton] TMA im2col 모드 — tma load op 수정

PR 링크: triton-lang/triton#9303 상태: Merged | 변경: +56 / -72

들어가며

TMA(Tensor Memory Accelerator)는 NVIDIA Hopper 이상 GPU에서 제공하는 하드웨어 가속 메모리 복사 엔진이다. 기존에는 tiled 모드만 지원했는데, 이 시리즈는 im2col(image to column) 모드를 추가한다. im2col은 convolution 연산에서 입력 텐서의 패치를 열(column)로 변환하여 행렬 곱셈으로 처리할 수 있게 하는 기법이다.

이 PR은 시리즈의 세 번째로, AsyncTMACopyGlobalToLocalOp에 im2col offset을 전달하는 인터페이스를 수정한다.

핵심 코드 분석

Before: offset 파라미터 없음

// AsyncTMACopyGlobalToLocalOp
def TTNG_AsyncTMACopyGlobalToLocalOp :
    TTNG_Op<"async_tma_copy_global_to_local"> {
  let arguments = (ins
    TT_TensorDescType:$desc_ptr,
    Variadic<I32>:$coords,
    TTNG_MBarrierType:$barrier,
    TTG_MemDescType:$result,
    I1:$pred,
    BoolAttr:$multicast
  );
}

After: im2col offset 추가

def TTNG_AsyncTMACopyGlobalToLocalOp :
    TTNG_Op<"async_tma_copy_global_to_local"> {
  let arguments = (ins
    TT_TensorDescOrIm2ColType:$desc_ptr,  // im2col 타입도 허용
    Variadic<I32>:$coords,
    Variadic<I16>:$im2col_offsets,         // im2col offset 추가
    TTNG_MBarrierType:$barrier,
    TTG_MemDescType:$result,
    I1:$pred,
    BoolAttr:$multicast
  );
}

Verifier도 업데이트되어 im2col 모드와 tiled 모드의 유효성을 구분 검증한다:

LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
  auto descType = getDescPtr().getType();
  if (isa<TensorDescIm2ColType>(descType)) {
    // im2col: offsets가 있어야 함
    if (getIm2colOffsets().empty())
      return emitError("im2col mode requires offsets");
  } else {
    // tiled: offsets가 없어야 함
    if (!getIm2colOffsets().empty())
      return emitError("tiled mode should not have offsets");
  }
  return success();
}

왜 이게 좋은가

  1. 하드웨어 기능 활용: TMA im2col 모드를 사용하면 소프트웨어 im2col 변환 없이 하드웨어가 직접 패치 추출을 수행하여 메모리 대역폭을 절약한다.
  2. 타입 안전성: TensorDescOrIm2ColType 유니온 타입과 verifier를 통해 im2col/tiled 모드의 파라미터 사용 오류를 컴파일 타임에 잡는다.
  3. 깔끔한 인터페이스: offset이 선택적(Variadic) 파라미터로 추가되어 기존 tiled 모드 사용자에게 영향이 없다.

정리

이 PR은 TMA load op에 im2col offset 파라미터를 추가하고, descriptor 타입을 im2col과 tiled 모두 허용하도록 확장했다. Verifier를 통해 모드별 파라미터 정합성을 보장한다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글