[triton] Triton에서 Ragged Mode를 위한 X Scale Swizzling 최적화
PR 링크: triton-lang/triton#8897 상태: Merged | 변경: +380 / -49
들어가며
최근 LLM 추론 및 학습에서 MXFP8(Microscaling Formats)과 같은 저정밀도 연산은 성능 향상을 위한 핵심 기술로 자리 잡았습니다. 특히 Triton은 이러한 하드웨어 가속 기능을 효율적으로 활용하기 위해 다양한 레이아웃 최적화를 제공합니다. 이번 PR은 Triton의 ragged 모드에서 X scale(activation scale)에 대한 swizzling을 지원하여, 행렬 곱셈(MatMul)의 지연 시간을 줄이는 것을 목표로 합니다. 기존에는 batched 모드에서만 제한적으로 지원되던 swizzling을 ragged 모드로 확장함으로써, 가변적인 시퀀스 길이를 가진 입력 데이터에서도 고성능 연산이 가능해졌습니다.
코드 분석
1. MatMul 오프셋 계산 로직 변경 (matmul_details/_common.py)
Ragged 모드에서 각 슬라이스(slice)는 가변적인 길이를 가집니다. 이를 128 단위로 패딩하여 TMA(Tensor Memory Accelerator)와 호환되게 만들기 위해 compute_offsets 함수에 XBlockOffs가 추가되었습니다.
# Before
off_x_slice = tl.load(XSliceOffs + off_w_z)
# After
off_x_slice = tl.load(XSliceOffs + off_w_z)
off_x_slice_tile = tl.load(XBlockOffs + off_w_z)
이 변경을 통해 현재 슬라이스의 시작점뿐만 아니라, 패딩된 블록 오프셋까지 정확히 추적할 수 있게 되었습니다.
2. TMA(Tensor Memory Accelerator) 활성화 (matmul.py)
기존에는 TMA 사용 조건이 매우 엄격했으나, BlackwellActMXScaleLayout을 사용하는 경우 persistent 커널 환경에서 X scale swizzling을 적극적으로 활용하도록 로직이 개선되었습니다.
# After
if a_has_mx and isinstance(a_scale.storage.layout, BlackwellActMXScaleLayout):
assert opt_flags.is_persistent, "swizzled x scale is only supported for persistent case"
assert opt_flags.block_m == 128 and opt_flags.block_k >= 128, "..."
a_scale_has_tma = True
3. P-MatMul 오프셋 적용 (matmul_details/_p_matmul.py)
_p_matmul 내부에서 TMA 모드에 따라 스케일 오프셋을 계산하는 방식이 분기되었습니다.
# After
if X_TMA_MODE == "dense":
off_m_scale = off_x_z * ((M + 127) // 128) + off_m // 128
else:
# slice_block_off_m points to the start of the current slice in the padded version
off_m_scale = slice_block_off_m + off_m // 128
왜 이게 좋은가
이번 최적화의 핵심은 **데이터 레이아웃의 정렬(Alignment)**입니다. MXFP8 연산 시 하드웨어는 128바이트 단위의 정렬된 메모리 접근에서 최상의 성능을 발휘합니다. Ragged 모드에서는 데이터가 불규칙하게 배치되어 있어 TMA를 직접 사용하기 어려웠으나, 이번 PR은 다음과 같은 이점을 제공합니다.
- 지연 시간 감소: TMA를 통한 효율적인 데이터 로딩으로 메모리 대역폭 활용도를 극대화합니다.
- 범용성 확보:
ragged모드에서도batched모드와 동일한 수준의 swizzling 최적화를 적용할 수 있게 되어, 다양한 모델 구조(예: MoE)에서 성능 이득을 볼 수 있습니다. - 교훈: 하드웨어 가속기(TMA 등)를 활용할 때는 데이터의 정렬이 필수적이며, 가변적인 데이터(ragged)를 고정된 블록 단위로 매핑하는 메타데이터 관리(padding)가 성능 최적화의 핵심임을 보여줍니다.
리뷰 과정에서 block_k를 128로 제한하여 공유 메모리(Smem) 사용량을 최적화하고 num_stages를 늘리는 등의 세밀한 튜닝이 논의되었으며, 이는 고성능 커널 작성 시 메모리 제약 조건을 고려하는 것이 얼마나 중요한지 잘 보여줍니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://triton-lang.org/main/index.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] 손상된 캐시 파일에 대한 방어적 처리 추가
- 현재글 : [triton] Triton에서 Ragged Mode를 위한 X Scale Swizzling 최적화
- 다음글 [Ray] 단일 노드 RDT 마이크로벤치마크 도입
댓글