본문으로 건너뛰기

[triton] Triton Blackwell 아키텍처를 위한 MXFP8 입력 스케일 스위즐링 최적화

PR 링크: triton-lang/triton#8863 상태: Merged | 변경: +163 / -46

들어가며

최신 AI 모델의 추론 및 학습 성능을 극대화하기 위해 NVIDIA의 Blackwell 아키텍처와 같은 차세대 하드웨어는 MXFP8(Microscaling Formats)과 같은 새로운 데이터 포맷을 지원합니다. 하지만 기존 Triton 구현에서 MXFP8 입력을 처리할 때, FP8 대비 약 1.7배의 지연 시간(latency)이 발생하는 병목 현상이 있었습니다. 본 PR은 Blackwell 아키텍처에서 입력 스케일(x scale)에 대한 스위즐링(swizzling)을 도입하고, TMA(Tensor Memory Accelerator)를 활용하여 데이터를 로드함으로써 이 성능 격차를 1.1배 수준으로 대폭 줄였습니다.

코드 분석

1. 테스트 케이스 확장 (python/triton_kernels/tests/test_matmul.py)

기존에는 hbm_swizzling이라는 단일 플래그로 스위즐링을 제어했으나, 이번 변경을 통해 입력(A)과 가중치(B)에 대한 스위즐링을 개별적으로 제어할 수 있도록 변경되었습니다.

# Before
Case(*shape, "plain", "bfloat16", "mxfloat4_e2m1", hbm_swizzling=True)

# After
Case(*shape, "plain", "bfloat16", "mxfloat4_e2m1", b_hbm_swizzling=True)

또한, _test_op 함수 내에서 a_hbm_swizzling에 대한 검증 로직을 추가하여, 현재 Blackwell(B200) 아키텍처와 배치(batched) 입력, 그리고 mxfloat8 데이터 타입에서만 이 최적화가 동작하도록 제한을 두었습니다.

2. 레이아웃 최적화 (python/triton_kernels/matmul.py)

입력 스케일의 메모리 접근 패턴을 하드웨어 친화적으로 만들기 위해 BlackwellActMXScaleLayout을 도입했습니다. 이는 TMA를 통해 데이터를 로드할 때 메모리 정렬을 최적화하여 대역폭 효율을 높이는 역할을 합니다.

# TMA를 활용한 스케일 로드 설정
scale_hbm_swizzling = layout.make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None

왜 이게 좋은가

이번 최적화의 핵심은 메모리 접근 패턴의 최적화입니다.

  1. TMA(Tensor Memory Accelerator) 활용: 기존의 일반적인 로드 방식 대신 TMA를 사용하여 스케일 데이터를 로드함으로써, 하드웨어 수준에서 메모리 요청을 병합하고 효율적으로 처리합니다.
  2. 스위즐링(Swizzling): 데이터를 메모리에 저장할 때 특정 패턴으로 재배치하여, GPU의 공유 메모리(Shared Memory) 뱅크 충돌을 방지합니다. 이는 행렬 곱셈 연산 시 데이터 로드 속도를 비약적으로 향상시킵니다.

성능 개선 수치 (B=M=N=K=1024)

  • Before: 143.8ms (FP8 대비 1.7배 느림)
  • After: 102.6ms (FP8 대비 1.1배 느림)

결과적으로 약 28% 이상의 성능 향상을 달성했습니다. 이 사례는 차세대 하드웨어의 특수 기능(TMA, MXFP8)을 활용하기 위해 데이터 레이아웃을 하드웨어 아키텍처에 맞게 재구성하는 것이 얼마나 중요한지 잘 보여줍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글