본문으로 건너뛰기

[triton] NVIDIA TMA im2col 모드 Gluon 튜토리얼 - Convolution 커널 구현

PR 링크: triton-lang/triton#9406 상태: Merged | 변경: +1861 / -1

들어가며

TMA(Tensor Memory Accelerator)의 im2col 모드는 Convolution 연산에 특화된 하드웨어 가속 기능입니다. 일반적으로 Convolution을 GEMM으로 변환하려면 im2col 변환이 필요한데, TMA im2col 모드를 사용하면 이 변환을 하드웨어가 자동 수행하므로 별도의 메모리 복사 없이 직접 shared memory에 데이터를 로드할 수 있습니다. 이 PR은 Gluon API를 통해 TMA im2col 모드를 활용하는 Convolution 커널 튜토리얼을 제공합니다.

핵심 코드 분석

1. Warp-Specialized 파티션 구조

@gluon.jit
def load_partition(p):
    """Load partition: issues TMA copies for input (im2col) and weight tiles."""
    ci_num_blocks = gl.cdiv(config.Ci, BLOCK_K)
    num_rs = config.R * config.S
    num_k_iter = num_rs * ci_num_blocks

    for k_iter in range(num_k_iter):
        index = k_iter % num_buffers
        phase = k_iter // num_buffers & 1
        mbarrier.wait(p.load_empty_bars.index(index), phase ^ 1)
        mbarrier.expect(ready_bar, ...)
        tma.async_copy_global_to_shared_im2col(...)

로드 파티션은 필터의 R*S 차원과 채널(Ci) 차원을 순회하면서, TMA im2col 모드로 입력 데이터를 shared memory에 비동기 복사합니다. mbarrier를 사용한 동기화로 생산자-소비자 패턴을 구현합니다.

2. ConvConfig 집계 타입

@aggregate
class ConvConfig:
    N: gl.tensor; H: gl.tensor; W: gl.tensor
    Ci: gl.tensor; Co: gl.tensor; R: gl.tensor; S: gl.tensor
    BLOCK_M: gl.constexpr; BLOCK_N: gl.constexpr; BLOCK_K: gl.constexpr

    @gluon.jit
    def get_program(self, pid):
        num_pid_m = gl.cdiv(M_GEMM, self.BLOCK_M)
        num_pid_n = gl.cdiv(N_GEMM, self.BLOCK_N)
        # grouped ordering for better L2 locality
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        return ConvProgram(self, pid_m, pid_n)

GEMM 매핑(M=Nout_hout_w, N=Co, K=RSCi)과 grouped ordering 타일 스케줄링을 @aggregate 타입으로 깔끔하게 캡슐화합니다.

왜 이게 좋은가

TMA im2col 모드는 Convolution에서 가장 비용이 큰 데이터 재배치 작업을 하드웨어가 처리하므로, 메모리 대역폭을 극대화할 수 있습니다. 이 튜토리얼은 Gluon의 저수준 API를 활용하여 warp specialization, 다중 버퍼링, mbarrier 동기화 등 Blackwell GPU의 핵심 기능을 모두 사용하는 완전한 Convolution 커널을 보여줍니다. Triton 사용자에게 하드웨어 최적 Convolution 구현의 참조 구현을 제공합니다.

정리

  • TMA im2col 모드로 하드웨어 가속 Convolution 로드 구현
  • Warp-specialized load/MMA/store 파티션 분리
  • mbarrier 기반 다중 버퍼 생산자-소비자 패턴
  • Grouped ordering 타일 스케줄링으로 L2 캐시 효율 극대화

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글