본문으로 건너뛰기

[Triton] MXFP Flash Attention 예제에 LDS reduction 적용

들어가며

AMD gfx1250의 MXFP(Microscaled FP) Flash Attention 예제에서 기존에는 softmax의 warp 간 reduction을 위해 2개의 커널을 사용했다. 이 PR은 LDS(Local Data Share)를 활용하여 단일 커널에서 warp 간 softmax를 수행하도록 변경한다.

핵심 코드 분석

Before: 2-커널 방식

MQA(Multi-Query Attention)에서 softmax의 max/sum을 워프 간에 공유하기 위해 중간 결과를 글로벌 메모리에 쓰고 별도 reduce 커널로 합산했다.

After: LDS reduction 방식

# 3D WMMA layout으로 warp 차원 추가
def get_wmma_layout(shape, num_warps, packed=False, preshuffled=False, warp_axis=0):
    rank = len(shape)
    assert rank == 2 or rank == 3
    warps_per_cta = [1] * rank
    warps_per_cta[warp_axis] = num_warps

KV 메모리 블록에 split-k 차원이 추가되었다:

k_block_shape = KVMemory.get_shuffle_shape(
    [BLOCK_N * SPLIT_K, HEAD_SZ // KV_PACK_DIV] if not SUBTILE else
    [BLOCK_N * SPLIT_K // 2, HEAD_SZ // KV_PACK_DIV])

K/V 버퍼를 꺼낼 때도 SPLIT_K에 맞게 reshape한다:

def get_k_buffer(self, sub_idx, buf):
    buffer = ...
    if cfg.SPLIT_K > 1:
        buffer = buffer.reshape([cfg.SPLIT_K, block_shape[0] // cfg.SPLIT_K, block_shape[1]])
        buffer = buffer.permute([0, 2, 1])
    return buffer

왜 이게 좋은가

  • 커널 수 감소: 2커널에서 1커널로 줄여 커널 launch 오버헤드와 글로벌 메모리 왕복을 제거한다.
  • LDS 활용: warp 간 reduction에 빠른 LDS를 사용하여 latency를 줄인다.
  • 3D layout 지원: WMMA layout이 2D에서 3D로 확장되어 split-k 패턴을 자연스럽게 표현한다.

정리

+337/-382로, 코드 양은 비슷하지만 아키텍처적으로 2커널에서 1커널로 통합하는 의미 있는 변경이다. LDS reduction은 GPU 프로그래밍에서 warp 간 통신의 표준 패턴이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글