본문으로 건너뛰기

[Axolotl] ScatterMoE LoRA Triton 커널의 autotune 탐색 공간 축소

PR 링크: axolotl-ai-cloud/axolotl#3525 상태: Merged | 변경: +12 / -12

들어가며

Triton의 @triton.autotune은 여러 커널 설정(block size, warp 수, pipeline stage 수)을 런타임에 벤치마크하여 최적 설정을 선택합니다. 그러나 탐색 공간이 너무 크면 첫 호출 시 오랜 컴파일 시간이 소요되고, 일부 설정은 GPU의 shared memory 제한을 초과하여 실패합니다.

핵심 코드 분석

scatter2scatter_lora (Forward 커널)

Before:

for block_m, block_n, block_k, warps, stages in product(
    [32, 64, 128],           # BLOCK_M
    [32, 64, 128, 256],      # BLOCK_N
    [32, 64, 128],           # BLOCK_K
    [4, 8],                  # num_warps
    [3, 4, 5],               # num_stages
):

After:

for block_m, block_n, block_k, warps, stages in product(
    [32, 64, 128],           # BLOCK_M
    [32, 64],                # BLOCK_N  (128, 256 제거)
    [32, 64, 128],           # BLOCK_K
    [4, 8],                  # num_warps
    [3, 4, 5],               # num_stages
):

BLOCK_N에서 128과 256을 제거했습니다. 이 변경으로 forward 커널의 설정 수가 72개에서 36개로 줄었습니다.

scatter2scatter_lora_dX (Backward dX 커널)

Before: BLOCK_K: [32, 64, 128, 256], BLOCK_N: [32, 64, 128, 256] After: BLOCK_K: [32, 64, 128], BLOCK_N: [32, 64]

group_bwd_lora (Backward dA/dB 커널)

Before: BLOCK_M: [32, 64, 128, 256], BLOCK_K: [32, 64, 128, 256], BLOCK_N: [32, 64, 128, 256] After: BLOCK_M: [32, 64, 128], BLOCK_K: [32, 64, 128], BLOCK_N: [32, 64]

총 설정 수: 768 -> 108로 약 7배 감소.

왜 이게 좋은가

  • 컴파일 시간 단축: autotune은 각 설정을 컴파일+실행하므로, 설정 수 감소가 직접적으로 첫 호출 시간을 줄입니다.
  • shared memory 초과 방지: BLOCK_N=256은 일부 GPU(~99KB shared memory)에서 초과하여 런타임 에러를 유발했습니다.
  • 성능 유지: 벤치마크 결과 큰 block size가 최적인 경우가 드물었습니다. LoRA의 rank가 작기 때문에(일반적으로 8-64) 큰 BLOCK_N이 불필요합니다.

정리

12줄 변경으로 autotune 탐색 공간을 7배 줄여 컴파일 시간을 단축하고 GPU 호환성을 개선했습니다. "더 많은 선택지 = 더 좋은 결과"가 아닌, 실측 기반의 탐색 공간 정리입니다.

참고 자료


이 포스트는 AI가 작성하였으며, 사실과 다를 수 있습니다. 정확한 정보는 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글