[triton] AMD GFX1250을 위한 Triton Stream-K 커널 최적화: 4/8 Warp 구현
PR 링크: triton-lang/triton#9370 상태: Merged | 변경: +1336 / -0
들어가며
최근 Triton 레포지토리에 AMD GFX1250 아키텍처를 타겟으로 한 Stream-K 커널 최적화 PR이 병합되었습니다. Stream-K는 행렬 곱셈(GEMM) 연산 시 워크로드 분산이 불균형할 때 발생하는 성능 저하를 해결하기 위한 기법입니다. 이번 PR은 특히 4 warp 및 8 warp 구성을 지원하고, atomic 기반의 스핀락과 효율적인 버퍼 관리를 통해 연산 효율을 극대화하는 데 초점을 맞추고 있습니다.
코드 분석
1. 256x256 타일 처리를 위한 쿼드런트 분할
기존의 단일 버퍼 저장 방식은 256x256과 같은 대형 타일에서 레지스터 압박을 유발합니다. 이를 해결하기 위해 split_accumulator_quadrant 함수가 도입되었습니다.
@gluon.jit
def split_accumulator_quadrant(accumulator, HALF_M, HALF_N, qm, qn):
acc_4d = accumulator.reshape([2, HALF_M, 2, HALF_N])
acc_4d = acc_4d.permute(1, 3, 0, 2)
acc_n0, acc_n1 = acc_4d.split()
# ... (생략)
return acc_m0 if qm == 0 else acc_m1
이 변경은 accumulator를 4개의 쿼드런트로 나누어 처리함으로써, 메모리 접근 패턴을 최적화하고 buffer_store 및 buffer_load 시의 대역폭 효율을 높입니다.
2. Atomic 기반의 Stream-K 동기화
Stream-K의 핵심은 여러 워크그룹이 부분적인 타일을 계산하고, 이를 최종 결과로 합치는 과정입니다. 이번 PR에서는 atomic_cas를 사용하여 효율적인 동기화를 구현했습니다.
# Owner: Aggregate contributors and store result
while ttgl.atomic_cas(locks_ptr + next_pid, 1, 1) != 1:
pass
# ... (Load and accumulate quadrants)
atomic_cas를 이용한 스핀락은 다른 워크그룹이 계산을 완료할 때까지 대기하며, 완료 즉시 데이터를 병합하여 지연 시간을 최소화합니다.
왜 이게 좋은가
이번 최적화의 핵심은 연산과 메모리 접근의 오버랩입니다.
- Persistent Loop Prefetch: 루프 내에서
async_load를 적극적으로 활용하여 prologue와 epilogue 사이의 연산 공백을 제거했습니다. - Atomic Spinning Locks: 기존의 무거운 동기화 방식 대신 가벼운 atomic 연산을 사용하여 GPU의 유휴 시간을 줄였습니다.
- 메모리 효율성: 256x256 타일을 쿼드런트 단위로 쪼개어 처리함으로써, 캐시 히트율을 높이고 레지스터 사용량을 최적화했습니다.
일반적으로 Stream-K 커널을 설계할 때는 워크그룹 간의 의존성을 최소화하고, 공유 메모리(또는 P-buffer) 접근 시의 충돌을 방지하는 것이 성능의 핵심입니다. 이번 PR은 GFX1250 아키텍처의 특성을 잘 활용한 좋은 사례입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] AMD PartitionedSharedEncodingAttr 도입으로 shared memory 파티셔닝 지원
- 현재글 : [triton] AMD GFX1250을 위한 Triton Stream-K 커널 최적화: 4/8 Warp 구현
- 다음글 [triton] ConSan 컴파일 타임 19분에서 34초로 단축 - 대규모 최적화
댓글