[triton] AMD GFX1250 MXFP Flash Attention 예제 커널 대규모 리팩터링
PR 링크: triton-lang/triton#9705 상태: Merged | 변경: +554 / -491
들어가며
Flash Attention은 트랜스포머 모델의 핵심 연산으로, GPU에서의 최적 구현이 매우 중요합니다. 이 PR은 AMD GFX1250을 위한 MXFP(Microscaling Floating Point) Flash Attention 예제 커널을 대규모로 리팩터링하여 코드를 단순화하고, 불필요한 preshuffle 로직을 제거하며, 더 효율적인 메모리 접근 패턴을 적용합니다.
핵심 코드 분석
1. Preshuffle 로직 제거
기존에는 K, V 텐서의 scale을 global memory에서 최적 벡터화를 위해 미리 재배치(preshuffle)하고 커널 내에서 다시 원래 형태로 복원(unshuffle)하는 복잡한 로직이 있었습니다.
Before (150+ 라인의 preshuffle/unshuffle 코드):
def preshuffle_kv_scale(x):
"""Preshuffle scales for scaled wmma instruction."""
# 복잡한 view, permute, contiguous 체인
x = x.view(batch, num_chunk_m, 4, preshuffle_factor // 4, ...)
x = x.permute(0, 1, 4, 3, 2, 5).contiguous()
return x.view(...)
After: 이 함수들이 모두 제거되었습니다. TDM이 메모리 레이아웃을 효율적으로 처리하므로 host 측 preshuffle이 불필요해졌습니다.
2. Python Slicing에서 expand_dims로 전환
Before:
offs = offs_m[:, None] * shape[1] + offs_n[None, :]
mask = (offs_m < shape[0])[:, None] & (offs_n < shape[1])[None, :]
After:
offs = expand_dims(offs_m, -1) * shape[1] + expand_dims(offs_n, -2)
mask = expand_dims(offs_m < shape[0], -1) & expand_dims(offs_n < shape[1], -2)
왜 이게 좋은가
이 리팩터링은 약 150줄의 복잡한 preshuffle/unshuffle 코드를 제거하면서 동시에 TDM(Triton Data Mover)의 하드웨어 기능을 더 적극적으로 활용합니다. 코드 단순화는 유지보수성과 가독성을 크게 향상시키며, host 측 데이터 재배치 작업을 제거하여 전처리 오버헤드도 줄입니다. store_layout의 제거와 TDM store로의 전환은 output 저장 경로를 하드웨어에 최적화된 방식으로 변경합니다.
정리
GFX1250 MXFP FA 예제에서 preshuffle/unshuffle 로직을 제거하고, TDM 기반 메모리 접근으로 전환하며, expand_dims API를 적용하여 코드를 약 60줄 줄이면서 하드웨어 활용도를 높였습니다.
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray] Ray Data에 cuDF 배치 포맷 추가
- 현재글 : [triton] AMD GFX1250 MXFP Flash Attention 예제 커널 대규모 리팩터링
- 다음글 [Ultralytics] 캘리브레이션 데이터셋이 배치보다 작을 때 에러 대신 자동 조정
댓글