[triton] AMD gfx1250 MXFP Flash Attention 예제 커널 업데이트
PR 링크: triton-lang/triton#9522 상태: Merged | 변경: +296 / -305
들어가며
이 PR은 AMD gfx1250의 MXFP Flash Attention Gluon 예제 커널을 대폭 리팩터링합니다. 주요 변경은 WMMA 레이아웃 생성 로직 개선, shared memory 레이아웃의 padding 옵션 추가, MemoryUnit/MemoryBlock 추상화 간소화 등입니다.
핵심 코드 분석
Before - 고정된 레이아웃
def get_store_layout(block_shape, num_warps):
if dim_inner == 64:
reg = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]]
lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]]
else:
assert dim_inner == 128
reg = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32]]
lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 64]]
After - 동적 레이아웃 생성
def get_store_layout(shape, num_warps):
dim_outer, dim_inner = shape
reg = [[0, 1], [0, 2]] # 기본 4 contiguous elements (128 bits)
tile_inner = 4
lane = [[1, 0], [2, 0], [4, 0], [8, 0]] # 16 lanes for outer dim
tile_outer = 16
# 동적으로 inner dim 절반까지 확장
while tile_inner < dim_inner // 2:
reg.append([0, tile_inner])
tile_inner <<= 1
lane.append([0, tile_inner]) # 나머지 16 lanes for inner dim
왜 이게 좋은가
- 유연한 레이아웃: head_size가 64, 128 외의 값이어도 동작하는 범용적인 레이아웃 생성 로직입니다.
- padding 옵션:
get_shared_layout(shape, padding=True)로 bank conflict 회피를 선택적으로 활성화합니다. - 코드 간소화: MemoryBlock/MemoryUnit 클래스를 제거하고 더 직관적인 함수형 인터페이스로 대체했습니다.
정리
Flash Attention 예제 커널의 레이아웃 로직을 범용화하고 추상화를 간소화한 리팩터링입니다. 특히 하드코딩된 조건 분기를 while 루프 기반의 동적 생성으로 바꾼 것이 핵심 개선입니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [faster-qwen3-tts] README 비스트리밍 RTF 수치 업데이트
- 현재글 : [triton] AMD gfx1250 MXFP Flash Attention 예제 커널 업데이트
- 다음글 [faster-qwen3-tts] 모드 간 성능 동등성 검증 및 벤치마크 비교 문서화
댓글