[Triton] MXFP Flash Attention 예제에 LDS reduction 적용
들어가며
AMD gfx1250의 MXFP(Microscaled FP) Flash Attention 예제에서 기존에는 softmax의 warp 간 reduction을 위해 2개의 커널을 사용했다. 이 PR은 LDS(Local Data Share)를 활용하여 단일 커널에서 warp 간 softmax를 수행하도록 변경한다.
핵심 코드 분석
Before: 2-커널 방식
MQA(Multi-Query Attention)에서 softmax의 max/sum을 워프 간에 공유하기 위해 중간 결과를 글로벌 메모리에 쓰고 별도 reduce 커널로 합산했다.
After: LDS reduction 방식
# 3D WMMA layout으로 warp 차원 추가
def get_wmma_layout(shape, num_warps, packed=False, preshuffled=False, warp_axis=0):
rank = len(shape)
assert rank == 2 or rank == 3
warps_per_cta = [1] * rank
warps_per_cta[warp_axis] = num_warps
KV 메모리 블록에 split-k 차원이 추가되었다:
k_block_shape = KVMemory.get_shuffle_shape(
[BLOCK_N * SPLIT_K, HEAD_SZ // KV_PACK_DIV] if not SUBTILE else
[BLOCK_N * SPLIT_K // 2, HEAD_SZ // KV_PACK_DIV])
K/V 버퍼를 꺼낼 때도 SPLIT_K에 맞게 reshape한다:
def get_k_buffer(self, sub_idx, buf):
buffer = ...
if cfg.SPLIT_K > 1:
buffer = buffer.reshape([cfg.SPLIT_K, block_shape[0] // cfg.SPLIT_K, block_shape[1]])
buffer = buffer.permute([0, 2, 1])
return buffer
왜 이게 좋은가
- 커널 수 감소: 2커널에서 1커널로 줄여 커널 launch 오버헤드와 글로벌 메모리 왕복을 제거한다.
- LDS 활용: warp 간 reduction에 빠른 LDS를 사용하여 latency를 줄인다.
- 3D layout 지원: WMMA layout이 2D에서 3D로 확장되어 split-k 패턴을 자연스럽게 표현한다.
정리
+337/-382로, 코드 양은 비슷하지만 아키텍처적으로 2커널에서 1커널로 통합하는 의미 있는 변경이다. LDS reduction은 GPU 프로그래밍에서 warp 간 통신의 표준 패턴이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글