본문으로 건너뛰기

[triton] NVIDIA TMA im2col 모드 Tensor Descriptor 지원

PR 링크: triton-lang/triton#9225 상태: Merged | 변경: +223 / -14

들어가며

NVIDIA의 Tensor Memory Accelerator(TMA)는 tiled 모드와 im2col 모드를 지원합니다. im2col 모드는 convolution 연산에 최적화된 메모리 접근 패턴으로, 입력 텐서를 행렬 곱셈에 적합한 형태로 자동 변환합니다. 이 PR은 Triton의 type system에 TensorDescInterface를 도입하여 두 모드를 통합적으로 다룹니다.

핵심 코드 분석

1. TensorDescInterface 도입

def TT_TensorDescInterface : TypeInterface<"TensorDescInterface"> {
  let methods = [
    InterfaceMethod<
      "Returns the block type of the tensor descriptor",
      "mlir::RankedTensorType", "getBlockType", (ins)
    >,
    InterfaceMethod<
      "Returns the block type with signless integer element type",
      "mlir::RankedTensorType", "getSignlessBlockType", (ins),
      // default implementation provided
    >,
  ];
}

기존 TensorDescType과 새로운 TensorDescIm2ColType이 공통 인터페이스를 구현하여, 하위 코드에서 구체적 타입에 의존하지 않고 처리할 수 있습니다.

2. TensorDescType에 인터페이스 적용

Before:

def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
  // getSignlessBlockType()이 extraClassDeclaration에 직접 정의
}

After:

def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc",
    [TT_TensorDescInterface]> {
  // getSignlessBlockType()이 인터페이스의 defaultImpl에서 제공
}

getSignlessBlockType()이 인터페이스의 default implementation으로 이동하여, 모든 tensor descriptor 타입에서 자동으로 사용 가능합니다.

3. 범용 타입 constraint

def TT_AnyTensorDescType : Type<
  CPred<"::mlir::isa<::mlir::triton::TensorDescInterface>($_self)">,
  "tensor descriptor type",
  "::mlir::triton::TensorDescInterface"
>;

이 constraint를 사용하면 어떤 종류의 tensor descriptor든 인자로 받을 수 있습니다.

왜 이게 좋은가

  • 다형성: 같은 연산이 tiled와 im2col descriptor를 모두 처리할 수 있어 코드 중복이 줄어듭니다.
  • 확장성: 새로운 TMA 모드가 추가되더라도 인터페이스만 구현하면 기존 코드와 호환됩니다.
  • MLIR 패턴 준수: TableGen 인터페이스를 활용한 깔끔한 타입 시스템 설계입니다.

정리

TMA im2col 모드를 위한 타입 시스템 기반을 마련한 PR입니다. TensorDescInterface 도입으로 tiled/im2col descriptor를 다형적으로 처리할 수 있게 되었습니다.

참고 자료


이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글