[triton] AMD WMMA Utilization 개선: Unroll 제거와 상수 폴딩
PR 링크: triton-lang/triton#9837 상태: Merged | 변경: +14 / -43
들어가며
GPU 커널 최적화에서 레지스터 압력(register pressure)은 성능에 직접적인 영향을 미칩니다. LLVM의 코드 생성기가 루프를 과도하게 언롤링하면 레지스터 사용량이 급증하여 스필링(spilling)이 발생하고, 이는 성능 저하로 이어집니다. 이 PR은 AMD GFX1250의 WMMA(Wave Matrix Multiply Accumulate) 활용률을 개선하기 위해 두 가지 변경을 적용합니다.
핵심 코드 분석
1. 불필요한 Unroll 제거
Before:
for block_id in range(block_min, block_max, 2 * BLOCK_N):
"""
unroll_factor=2 to save computation wrt iter_id
"""
# 1/2 of unrolled loop
t_1 = block_id + BLOCK_N
t_2 = block_id + 2 * BLOCK_N
t_3 = block_id + 3 * BLOCK_N
# ... 첫 번째 반복 로직 ...
v = pgm.tdm_shared_load_v(0, wait_count=2)
pgm.tdm_load_global_to_shared_k([t_3, 0], 1)
k = pgm.tdm_shared_load_k(0, wait_count=2)
pgm.tdm_load_global_to_shared_v([t_2, 0], 0)
# 2/2 of unrolled loop (거의 동일한 코드 반복)
v = pgm.tdm_shared_load_v(1, wait_count=2)
pgm.tdm_load_global_to_shared_k([t_3, 0], 0)
k = pgm.tdm_shared_load_k(1, wait_count=2)
pgm.tdm_load_global_to_shared_v([t_2, 0], 1)
After:
iter_id = 0
for block_id in range(block_min, block_max, BLOCK_N):
# 단일 루프 본문, 모듈로 연산으로 버퍼 인덱스 관리
v = pgm.tdm_shared_load_v(iter_id % NUM_BUFFERS, wait_count=2)
pgm.tdm_load_global_to_shared_k([t_3, 0], (iter_id + 1) % NUM_BUFFERS)
k = pgm.tdm_shared_load_k(iter_id % NUM_BUFFERS, wait_count=2)
pgm.tdm_load_global_to_shared_v([t_2, 0], iter_id % NUM_BUFFERS)
iter_id += 1
수동 2배 언롤을 제거하고, iter_id % NUM_BUFFERS로 rotating register 패턴을 구현했습니다. 코드 크기가 절반으로 줄어들었고, LLVM이 더 나은 레지스터 할당을 수행할 수 있게 되었습니다.
2. 상수 폴딩으로 VALU 연산 감소
# Before: 런타임에 매번 곱셈 수행
m_ij_scaled = m_ij * self.sm_scale * self.rcp_ln2
# After: 컴파일 타임에 상수 미리 계산
self.sm_scale_dot_rcp_ln2 = self.sm_scale * self.rcp_ln2
m_ij_scaled = m_ij * self.sm_scale_dot_rcp_ln2
sm_scale * rcp_ln2는 커널 실행 중 변하지 않는 값이므로, 초기화 시점에 미리 계산하여 루프 내에서 불필요한 곱셈 VALU 명령어를 제거했습니다.
왜 이게 좋은가
이 PR은 **"최신 LLVM의 코드 생성 특성에 맞춰 소스 레벨에서 적응한다"**는 실용적 접근을 보여줍니다. 이론적으로 수동 언롤은 rotating register를 줄여 성능에 유리해야 하지만, 실제로는 LLVM의 언롤링 휴리스틱이 이미 충분히 효과적이며, 수동 언롤이 오히려 레지스터 스필링을 유발했습니다. 상수 폴딩과 결합하여 코드 간결성과 성능을 동시에 개선한 좋은 사례입니다.
정리
- 수동 2배 언롤을 제거하여 LLVM 코드 생성기의 레지스터 스필링 방지
sm_scale * rcp_ln2상수 폴딩으로 루프 내 VALU 연산 감소- 코드 라인 수 43줄 감소, 유지보수성 향상
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Open WebUI] sendMessage에서 중복 getChatList 호출 제거
- 현재글 : [triton] AMD WMMA Utilization 개선: Unroll 제거와 상수 폴딩
- 다음글 [Axolotl] 플러그인에 scored rollout 디스패치, 외부 플러그인 경로 확장, vLLM 에러 처리 개선
댓글