본문으로 건너뛰기

[triton] AMD gfx1250 MXFP Flash Attention 예제 커널 업데이트

PR 링크: triton-lang/triton#9522 상태: Merged | 변경: +296 / -305

들어가며

이 PR은 AMD gfx1250의 MXFP Flash Attention Gluon 예제 커널을 대폭 리팩터링합니다. 주요 변경은 WMMA 레이아웃 생성 로직 개선, shared memory 레이아웃의 padding 옵션 추가, MemoryUnit/MemoryBlock 추상화 간소화 등입니다.

핵심 코드 분석

Before - 고정된 레이아웃

def get_store_layout(block_shape, num_warps):
    if dim_inner == 64:
        reg = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]]
        lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]]
    else:
        assert dim_inner == 128
        reg = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]]
        lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 64]]

After - 동적 레이아웃 생성

def get_store_layout(shape, num_warps):
    dim_outer, dim_inner = shape
    reg = [[0, 1], [0, 2]]  # 기본 4 contiguous elements (128 bits)
    tile_inner = 4
    lane = [[1, 0], [2, 0], [4, 0], [8, 0]]  # 16 lanes for outer dim
    tile_outer = 16
    # 동적으로 inner dim 절반까지 확장
    while tile_inner < dim_inner // 2:
        reg.append([0, tile_inner])
        tile_inner <<= 1
    lane.append([0, tile_inner])  # 나머지 16 lanes for inner dim

왜 이게 좋은가

  1. 유연한 레이아웃: head_size가 64, 128 외의 값이어도 동작하는 범용적인 레이아웃 생성 로직입니다.
  2. padding 옵션: get_shared_layout(shape, padding=True)로 bank conflict 회피를 선택적으로 활성화합니다.
  3. 코드 간소화: MemoryBlock/MemoryUnit 클래스를 제거하고 더 직관적인 함수형 인터페이스로 대체했습니다.

정리

Flash Attention 예제 커널의 레이아웃 로직을 범용화하고 추상화를 간소화한 리팩터링입니다. 특히 하드코딩된 조건 분기를 while 루프 기반의 동적 생성으로 바꾼 것이 핵심 개선입니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글