[triton] [NVIDIA] SM120을 위한 FP4 Native Scaled Matmul 지원 및 성능 최적화 분석
PR 링크: triton-lang/triton#8494 상태: Merged | 변경: +62 / -32
들어가며
최근 대규모 언어 모델(LLM)의 효율적인 추론을 위해 FP8을 넘어 FP4와 같은 저정밀도 데이터 타입에 대한 하드웨어 가속 요구가 높아지고 있습니다. NVIDIA의 차세대 아키텍처(SM120 등)는 이러한 저정밀도 연산을 위한 전용 하드웨어 유닛을 포함하고 있습니다.
기존 Triton의 FP4 구현은 하드웨어의 기능을 온전히 활용하지 못하고, FP4를 상위 정밀도(FP16/FP32)로 변환(Unpacking)한 뒤 연산하는 Decomposition Fallback 방식을 사용했습니다. 이는 불필요한 메모리 대역폭 소모와 연산 오버헤드를 발생시켰습니다.
이번 PR은 SM120 아키텍처에서 Native FP4 Scaled Matmul을 지원하도록 개선하여, 실제 Llama3-8B 벤치마크에서 실행 시간을 61초에서 33초로 약 46% 단축시키는 놀라운 최적화 성과를 보여주었습니다.
코드 분석: 핵심 변경 사항
1. Fallback 로직 제거 및 Native 경로 확보
기존에는 FP4 타입이 들어오면 tt.dot_scaled 연산을 수행하지 못하고 일반 tt.dot으로 분해(Decompose)했습니다. 이를 수정하여 하드웨어 가속이 가능한 경로로 유도합니다.
Before (AccelerateMatmul.cpp):
// E5M2/E4M3(FP8)만 지원하고 나머지는 실패 처리
if (!((dotOp.getAElemType() == ScaleDotElemType::E5M2 ||
dotOp.getBElemType() == ScaleDotElemType::E4M3) && ...)) {
return rewriter.notifyMatchFailure(dotOp, "only E5M2/E4M3 is supported");
}
After (AccelerateMatmul.cpp):
// FP4를 포함한 더 넓은 범위의 타입을 수용하기 위해 제약 조건 삭제
// 이제 FP4 타입도 failure()를 반환하지 않고 MMA(Matrix Multiply-Accumulate) 인코딩으로 변환됩니다.
2. PTX MMA 인스트럭션 매핑 확장
NVIDIA GPU의 어셈블리 수준 언어인 PTX에서 FP4 가속을 위한 mma.sync 명령어를 정의합니다. 특히 mxf4nvf4와 같은 새로운 하드웨어 모드를 명시합니다.
MMAv2.cpp 변경점:
{TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X,
"mma.sync.aligned.m16n8k64.row.col."
"kind::mxf4nvf4.block_scale.scale_vec::"
"2X.f32.e2m1.e2m1.f32.ue8m0"},
{TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X,
"mma.sync.aligned.m16n8k64.row.col."
"kind::mxf4nvf4.block_scale.scale_vec::"
"4X.f32.e2m1.e2m1.f32.ue4m3"},
여기서 scale_vec::2X 또는 4X는 하드웨어가 한 번에 처리하는 스케일 벡터의 크기를 의미하며, 이는 FP4의 높은 압축률을 하드웨어가 직접 핸들링할 수 있게 합니다.
3. Scale Vector 패킹 로직 구현
FP4 연산 시 스케일 값들을 하드웨어 레지스터 규격에 맞게 패킹하는 packElements 헬퍼 함수가 추가되었습니다. 여러 개의 바이트를 하나의 32비트 레지스터로 병합하는 비트 연산 최적화가 핵심입니다.
MMAv2.cpp 내 packElements:
auto packElements = [&](ArrayRef<Value> bytes, int loc, int numBytes) -> Value {
Value packed = tb.zext(i32, bytes[loc]);
for (int i = 1; i < numBytes; ++i) {
Value byte = tb.zext(i32, bytes[loc + i]);
Value shifted = tb.shl(byte, tb.i32_val(i * 8)); // 8비트씩 시프트하여 패킹
packed = tb.or_(packed, shifted);
}
return packed;
};
왜 이게 좋은 최적화인가?
1. 극적인 성능 향상 (E2E vLLM Benchmark)
- 기존 (Main Branch): 61 sec
- 개선 후 (This PR): 33 sec
- 결과: 약 1.85배의 속도 향상이 관찰되었습니다. 이는 소프트웨어적인 에뮬레이션(Decomposition)을 하드웨어 고유 기능(Native Support)으로 대체했을 때 얻을 수 있는 전형적인 이득입니다.
2. 메모리 및 레지스터 효율성
기존 방식은 FP4 데이터를 연산 전에 FP16/32로 확장해야 했으므로 레지스터 점유율이 높았습니다. Native FP4를 사용하면 하드웨어가 직접 저정밀도 데이터를 읽어 연산하므로 레지스터 압박(Register Pressure)이 줄어들고 더 큰 배치 사이즈나 복잡한 커널을 실행할 수 있는 여유가 생깁니다.
3. 유연한 스케일링 지원
리뷰어 @masahi의 조언에 따라 LinearLayout을 scale_vec 모드에 구애받지 않도록(Agnostic) 설계하여, 향후 다른 정밀도나 새로운 하드웨어가 추가되어도 유연하게 대응할 수 있는 구조를 갖추었습니다.
마치며
이번 PR은 최신 NVIDIA GPU 아키텍처의 잠재력을 끌어내기 위해 컴파일러 레벨에서 어떤 작업이 필요한지 잘 보여줍니다. 단순히 코드를 짜는 것을 넘어, 하드웨어의 PTX 명세와 데이터 패킹 방식을 깊이 이해하고 이를 MLIR 다이얼렉트 변환 과정에 녹여낸 훌륭한 사례입니다.
LLM 추론 가속화를 고민하는 엔지니어라면, Triton과 같은 커널 언어가 어떻게 하드웨어의 저정밀도 가속기를 활용하는지 이 PR을 통해 학습해 보시길 권장합니다.
참고 자료
- https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
- https://pytorch.org/docs/stable/generated/torch.cuda.get_device_capability.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [triton] Triton NVIDIA GPU 백엔드: WarpGroupDotWaitOp 최적화 및 동기화 개선
- [triton] [Blackwell] NVIDIA 차세대 아키텍처를 위한 Triton의 tcgen05.ld.red 최적화 분석
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- [triton] GSan AxisInfo 기반 Shadow Update 중복 제거로 2~10배 성능 향상
- [triton] Triton AMD 백엔드 최적화: SGPR 활용과 루프 최적화를 통한 GEMM 성능 향상
PR Analysis 의 다른글
- 이전글 [Triton] Gluon 레이아웃 검증 에러 메시지 개선
- 현재글 : [triton] [NVIDIA] SM120을 위한 FP4 Native Scaled Matmul 지원 및 성능 최적화 분석
- 다음글 [Triton] AxisInfo의 unrealized_conversion_cast 처리 강화
댓글