[triton] MMAv2 dot에 Prefetch 재활성화 - 루프 프롤로그 분리 방식으로 재설계
PR 링크: triton-lang/triton#9806 상태: Merged | 변경: +565 / -115
들어가며
Prefetch는 dot 연산의 피연산자를 미리 로드하여 shared memory 접근 지연을 숨기는 최적화입니다. MMAv2(Ampere/Hopper) dot에 대한 prefetch가 이전에 비활성화되어 있었는데, 이 PR은 루프 프롤로그 분리 방식으로 완전히 재설계하여 재활성화합니다.
핵심 코드 분석
기존 prefetch는 K 차원을 분할하여 subview로 처리했지만, 새 방식은 루프의 첫 번째 반복을 프롤로그로 분리합니다.
After (새 패턴):
// 프롤로그: 첫 반복의 operand를 미리 로드
%a_view0 = ttg.memdesc_index %a[%idx_next0]
%wait0 = ttg.async_wait %tok0, %tok1 {num = 4 : i32}
%a0 = ttg.local_load %a_view0 token %wait0
%b0 = ttg.local_load %b_view0 token %wait0
// 루프: 현재 반복의 연산 + 다음 반복의 prefetch
%loop = scf.for ... iter_args(..., %a_prefetch = %a0, %b_prefetch = %b0) {
%wait_next = ttg.async_wait ...
%a_rem = ttg.local_load %a_tail token %wait
%dot0 = tt.dot %a_prefetch, %b_prefetch, %acc
%a_next = ttg.local_load %next_a_head token %wait_next
%acc_next = tt.dot %a_rem, %b_rem, %dot0
scf.yield ..., %a_next, %b_next
}
각 반복에서 현재 반복의 나머지 부분(tail)과 다음 반복의 시작 부분(head)을 로드하여, dot 연산과 메모리 접근이 오버랩됩니다.
왜 이게 좋은가
루프 프롤로그 분리는 소프트웨어 파이프라이닝의 고전적 기법입니다. 첫 반복의 로드를 루프 외부로 분리하면, 루프 내에서는 항상 "현재 반복 연산"과 "다음 반복 로드"가 병렬로 실행됩니다. 이는 MMAv2 dot에서 shared memory -> register 전송 지연을 효과적으로 숨깁니다.
정리
- MMAv2 dot에 대한 prefetch를 루프 프롤로그 분리 방식으로 재설계
- 첫 반복 프롤로그 + 루프 내 prefetch 패턴 구현
- K 차원 subview 대신 루프 구조 변환 방식 채택
- async wait token 기반 동기화 지원
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray Data] _map_task 공통 인자 캐싱으로 직렬화 오버헤드 절감
- 현재글 : [triton] MMAv2 dot에 Prefetch 재활성화 - 루프 프롤로그 분리 방식으로 재설계
- 다음글 [Ray Data] PyArrow 스키마 해싱 방식 개선으로 대규모 데이터셋 성능 향상
댓글