본문으로 건너뛰기

[triton] [Blackwell] NVIDIA 차세대 아키텍처를 위한 Triton의 tcgen05.ld.red 최적화 분석

PR 링크: triton-lang/triton#9151 상태: Merged | 변경: +872 / -82

들어가며

NVIDIA의 차세대 아키텍처인 **Blackwell(SM100+)**은 텐서 코어와 메모리 계층 구조에서 혁신적인 변화를 가져왔습니다. 그중 하나가 **Tensor Memory(TMEM)**와 이를 활용한 새로운 명령어 셋인 tcgen05입니다.

이번에 분석할 PR은 Triton의 고성능 커널 라이브러리인 Gluon에 Blackwell 전용 기능인 tcgen05.ld.red 명령어를 도입한 변경사항입니다. 이 명령어의 핵심은 TMEM에서 데이터를 레지스터로 로드하는 동시에, 특정 차원(N-dimension)에 대한 리덕션(Min/Max) 연산을 하드웨어 수준에서 수행한다는 점입니다. 이를 통해 별도의 리덕션 커널이나 추가적인 레지스터 연산 없이도 통계치(예: Softmax의 Max값)를 계산할 수 있어 성능이 크게 향상됩니다.

코드 분석: 핵심 변경 사항

1. IR 정의: TMEMLoadOp의 확장

기존의 TMEMLoadOp는 단순히 TMEM에서 데이터를 가져오는 역할만 했습니다. 이번 PR에서는 리덕션 옵션을 추가하여 하드웨어 가속을 활용할 수 있도록 확장되었습니다.

Before (TritonNvidiaGPUOps.td):

def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
  let arguments = (ins
    Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src,
    Optional<TTG_AsyncToken>:$dep
  );
  let results = (outs
    TT_Tensor:$result,
    Optional<TTG_AsyncToken>:$token
  );
}

After (TritonNvidiaGPUOps.td):

def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load", [AttrSizedResultSegments]> {
  let arguments = (ins
    Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src,
    Optional<TTG_AsyncToken>:$dep,
    OptionalAttr<TTNG_TMEMLoadReduceModifierEnum>:$redOp, // MIN/MAX 추가
    OptionalAttr<BoolAttr>:$abs, // 절대값 옵션
    OptionalAttr<BoolAttr>:$NaN  // NaN 전파 옵션
  );
  let results = (outs
    TT_Tensor:$result,
    Optional<TTG_AsyncToken>:$token,
    Optional<TT_Tensor>:$red // 리덕션 결과값(M 차원 벡터)
  );
}

이 변경을 통해 tmem_load는 이제 원본 데이터뿐만 아니라 리덕션된 결과값(red)까지 동시에 반환할 수 있는 구조를 갖추게 되었습니다.

2. Verifier를 통한 하드웨어 제약 조건 강제

tcgen05.ld.red는 하드웨어 명령어이므로 매우 엄격한 레이아웃 제약 조건을 가집니다. 특히 리덕션이 일어나는 N 차원이 스레드 간에 분산(sharded)되어 있으면 안 됩니다. 이를 검증하기 위한 로직이 Ops.cpp에 추가되었습니다.

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp:

// N 차원이 스레드 간에 공유되지 않고 레지스터 내에 온전히 있는지 확인
auto regDims = toLinearEncoding(regTy).basesPerDim(kReg);
if (regDims[dimN] != toLinearLayout(regTy).getOutDimSizes().begin()[dimN] ||
    regDims[dimM] != 1) {
  return emitOpError("tmem_load reduction with N dimension sharded across "
                     "threads is not supported.");
}

리뷰어 Thomas Raoux는 이 부분에서 레이아웃 가정이 깨질 위험을 경고했고, 이에 따라 unpacked=false 조건 및 TMEM 레이아웃에 대한 명시적인 체크가 강화되었습니다.

3. Gluon Python API의 변화

사용자 레벨에서는 load_max, load_min과 같은 직관적인 API를 통해 이 기능을 사용할 수 있습니다.

python/examples/gluon/01-attention-forward.py:

# Before: 일반 로드 후 별도 처리 필요
# qks = qks + (s_tmem.slice(i * SIZE, SIZE).load(layout), )

# After: 로드와 동시에 Max 값 획득
if use_tmem_red:
    vals, reds = s_tmem.slice(i * SIZE, SIZE).load_max(layout)
    red_total = reds if red_total is None else gl.maximum(red_total, reds)
    qks = qks + (vals, )

Flash Attention의 Forward 패스에서 Softmax를 위한 Max 값을 구할 때, 데이터를 읽어오면서 동시에 하드웨어가 Max를 찾아주므로 루프 내의 연산 오버헤드가 줄어듭니다.

왜 이게 좋은 최적화인가?

1. 성능 향상 (TFLOPS)

PR 설명에 포함된 벤치마크 결과에 따르면, Blackwell(B300) 환경에서 has_tmem_red=True일 때 눈에 띄는 성능 향상이 관찰됩니다.

  • FP8 Attention (D=128, N_CTX=32768): 1808 TFLOPS -> 1998 TFLOPS (약 10.5% 향상)
  • FP8 Causal Attention (D=128, N_CTX=32768): 1870 TFLOPS -> 2012 TFLOPS

2. 메모리 대역폭 및 레지스터 절약

일반적으로 리덕션을 수행하려면 데이터를 로드한 후, 레지스터 수준에서 max 연산을 수행해야 합니다. tcgen05.ld.red는 로드 유닛 내부에서 이 작업을 처리하므로:

  • 추가적인 ALU 연산 사이클이 필요 없습니다.
  • 중간 단계의 리덕션을 위한 추가 레지스터 소모를 줄일 수 있습니다.

3. 하드웨어 추상화의 정석

이 PR은 복잡한 PTX 명령어(tcgen05.ld.red.sync.aligned...)를 Triton의 상위 IR 수준에서 TMEMLoadOp의 속성으로 추상화했습니다. 이는 컴파일러가 레이아웃을 최적화할 수 있는 여지를 남기면서도, 개발자가 하드웨어 특화 기능을 손쉽게 사용할 수 있게 해줍니다.

결론 및 교훈

이번 최적화는 **"데이터 이동과 연산의 결합"**이 현대 GPU 아키텍처에서 얼마나 중요한지를 보여줍니다. 단순히 데이터를 빨리 옮기는 것을 넘어, 옮기는 과정에서 부가적인 연산을 처리하는 하드웨어 기능을 소프트웨어 스택(Triton)이 얼마나 기민하게 수용하느냐가 딥러닝 프레임워크의 경쟁력이 됩니다.

시니어 엔지니어로서 배울 점은, 새로운 하드웨어 기능을 도입할 때 단순히 기능을 추가하는 것에 그치지 않고, Verifier를 통해 하드웨어의 제약 사항을 컴파일 타임에 엄격히 체크하여 런타임 에러를 방지하는 설계 철학입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글