[triton] NVIDIA TMA im2col 모드 Tensor Descriptor 지원
PR 링크: triton-lang/triton#9225 상태: Merged | 변경: +223 / -14
들어가며
NVIDIA의 Tensor Memory Accelerator(TMA)는 tiled 모드와 im2col 모드를 지원합니다. im2col 모드는 convolution 연산에 최적화된 메모리 접근 패턴으로, 입력 텐서를 행렬 곱셈에 적합한 형태로 자동 변환합니다. 이 PR은 Triton의 type system에 TensorDescInterface를 도입하여 두 모드를 통합적으로 다룹니다.
핵심 코드 분석
1. TensorDescInterface 도입
def TT_TensorDescInterface : TypeInterface<"TensorDescInterface"> {
let methods = [
InterfaceMethod<
"Returns the block type of the tensor descriptor",
"mlir::RankedTensorType", "getBlockType", (ins)
>,
InterfaceMethod<
"Returns the block type with signless integer element type",
"mlir::RankedTensorType", "getSignlessBlockType", (ins),
// default implementation provided
>,
];
}
기존 TensorDescType과 새로운 TensorDescIm2ColType이 공통 인터페이스를 구현하여, 하위 코드에서 구체적 타입에 의존하지 않고 처리할 수 있습니다.
2. TensorDescType에 인터페이스 적용
Before:
def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
// getSignlessBlockType()이 extraClassDeclaration에 직접 정의
}
After:
def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc",
[TT_TensorDescInterface]> {
// getSignlessBlockType()이 인터페이스의 defaultImpl에서 제공
}
getSignlessBlockType()이 인터페이스의 default implementation으로 이동하여, 모든 tensor descriptor 타입에서 자동으로 사용 가능합니다.
3. 범용 타입 constraint
def TT_AnyTensorDescType : Type<
CPred<"::mlir::isa<::mlir::triton::TensorDescInterface>($_self)">,
"tensor descriptor type",
"::mlir::triton::TensorDescInterface"
>;
이 constraint를 사용하면 어떤 종류의 tensor descriptor든 인자로 받을 수 있습니다.
왜 이게 좋은가
- 다형성: 같은 연산이 tiled와 im2col descriptor를 모두 처리할 수 있어 코드 중복이 줄어듭니다.
- 확장성: 새로운 TMA 모드가 추가되더라도 인터페이스만 구현하면 기존 코드와 호환됩니다.
- MLIR 패턴 준수: TableGen 인터페이스를 활용한 깔끔한 타입 시스템 설계입니다.
정리
TMA im2col 모드를 위한 타입 시스템 기반을 마련한 PR입니다. TensorDescInterface 도입으로 tiled/im2col descriptor를 다형적으로 처리할 수 있게 되었습니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] AMD PrepareIfCombining 패스 추가 — scf.if 병합 최적화
- 현재글 : [triton] NVIDIA TMA im2col 모드 Tensor Descriptor 지원
- 다음글 [Triton] TMA im2col 모드 — tma load op 수정
댓글