본문으로 건너뛰기

[triton] Triton 커널 최적화: Mask Sorting을 통한 Reduction 연산 가속화

PR 링크: triton-lang/triton#10317 상태: Merged | 변경: +0 / -0

들어가며

딥러닝 연산에서 Reduction(합계, 평균 등)은 매우 빈번하게 발생하며, 성능에 큰 영향을 미치는 핵심 연산입니다. 특히 입력 데이터의 일부만 계산에 참여하게 하는 Masking이 적용된 경우, 커널의 효율성은 더욱 중요해집니다.

기존의 Triton reduce_forward 커널은 마스크의 분포와 상관없이 고정된 루프 범위를 탐색하는 경향이 있었습니다. 예를 들어, 한 행(row)에 활성화된 요소가 1개뿐인데도 전체 루프 범위인 8을 모두 확인해야 한다면 7번의 불필요한 연산이 발생합니다.

이번에 분석할 PR은 "broadcast_n" 마스크 패턴이 사용될 때, 활성 입력 행의 개수에 따라 데이터를 정렬(Sorting)하고 루프 바운드를 하드코딩하여 성능을 최대 7.2%까지 끌어올린 최적화 사례입니다.


코드 분석: 핵심 변경 사항

1. 워크로드 기반의 행 정렬 (Mask Sorting)

가장 핵심적인 아이디어는 _create_row_idxs라는 새로운 커널의 도입입니다. 이 커널은 각 행에서 실제로 계산이 필요한 요소의 개수(ActiveInputCounts)를 계산하고, 비슷한 작업량을 가진 행들을 그룹화합니다.

[After: _create_row_idxs 커널 도입]

@triton.jit
def _create_row_idxs(Mask, stride_mr, stride_m0, ...):
    # ... 생략 ...
    # 각 행별 활성 요소 개수 계산
    n_actives = tl.sum((mask != 0).to(tl.int32), axis=1)
    tl.store(ActiveInputCounts + idxs, n_actives, mask=valid)

    # 활성 요소 개수에 따라 행 인덱스를 그룹화 (4개 이상, 3개, 2개, 1개 이하)
    has_4_or_more = valid & (n_actives >= 4)
    # ... (중략) ...
    has_3 = valid & (n_actives == 3)
    # ... (중략) ...
    # 정렬된 RowIdxs를 저장하여 메인 커널에서 순차적으로 처리하게 함
    tl.store(RowIdxs + tl.gather(idxs, idxs4, axis=0), idxs, mask=has_4_or_more)

이 과정은 엄격한 argsort가 아니라, 작업량에 따른 버킷 정렬(Bucket Sort)에 가깝습니다. 이를 통해 메인 커널은 "무거운 행"부터 "가벼운 행"까지 차례대로 처리할 수 있게 됩니다.

2. 루프 바운드 최적화 및 비트맵 활용

정렬된 인덱스를 바탕으로 메인 로직인 _reduce_forward_innerLIMIT라는 인자를 전달받습니다. 이 LIMIT는 해당 블록이 처리해야 할 최대 활성 요소 개수를 의미하며, 이를 통해 루프를 조기에 종료하거나 컴파일 타임에 최적화할 수 있습니다.

[After: 비트맵을 활용한 조건부 로드]

@triton.jit(noinline=True)
def _reduce_forward_inner(..., LIMIT: tl.constexpr):
    # ...
    if USE_BITMAP:
        # 정렬된 인덱스로 행을 불러옴
        offs_s0 = tl.load(RowIdxs + offs_s0, mask=valid_s0, other=0)
        
        # 마스크를 비트맵(uint32)으로 압축하여 어떤 요소가 활성 상태인지 빠르게 파악
        m = (m != 0).to(tl.uint32) << tl.arange(0, BLOCK_K)[None, :]
        bitmap = tl.sum(m, axis=1)

기존에는 마스크를 매번 확인하며 분기 처리를 했다면, 이제는 비트맵을 통해 한 번에 활성 상태를 파악하고 LIMIT만큼만 루프를 돌게 됩니다.

3. 런타임 설정 최적화 (Chaining Factor)

대규모 텐서 처리를 위해 chain_factor 개념이 강화되었습니다. 이는 하나의 프로그램(CTA)이 연속적으로 여러 블록을 처리하게 하여 커널 런칭 오버헤드를 줄이고 캐시 효율을 높입니다.

[Before vs After: 설정 로직 변경]

# Before
return OptFlags(block_s0, block_s1, block_s1, 4, use_static_loop)

# After
if mask_chainable and S0 >= 32768 and x_dtype.itemsize >= 2:
    # SM 개수와 데이터 크기에 따라 chain_factor 계산
    chain_factor = (grid_m * grid_n) // (target_info.num_sms() * 4)
    chain_factor = min(max(1, chain_factor), grid_n)
    # ...
    if chain_factor > 1:
        use_static_loop = False # 체이닝 시 동적 루프 사용
else:
    chain_factor = 1

return OptFlags(32, 128, 128 // reduction_n, num_warps, use_static_loop, chain_factor)

왜 이게 좋은 최적화인가?

1. 제어 흐름 분산(Control Divergence) 감소

GPU는 Warp 단위로 명령어를 실행합니다. 만약 같은 Warp 내의 스레드들이 서로 다른 루프 횟수를 가진다면 성능 저하가 발생합니다. 이 PR은 작업량이 비슷한 행들을 미리 모아줌으로써 Warp 내의 모든 스레드가 최대한 비슷한 시점에 루프를 종료하게 만듭니다.

2. 컴파일러 최적화 유도

LIMITtl.constexpr로 전달함으로써 Triton 컴파일러는 루프 언롤링(Unrolling)과 데드 코드 제거(Dead Code Elimination)를 더 공격적으로 수행할 수 있습니다. 특히 LIMIT가 작은 경우(예: 2 또는 3), 불필요한 메모리 접근을 완전히 제거할 수 있습니다.

3. 실질적인 성능 향상

GB300(NVIDIA의 차세대 아키텍처로 추정) 벤치마크 결과, 배치 사이즈가 커질수록 최적화 효과가 두드러졌습니다.

  • B=8192: +3.6%
  • B=65536: +7.2% (8998 GB/s -> 9650 GB/s)

메모리 대역폭(Bandwidth)이 이미 9TB/s에 육박하는 극한의 상황에서 7% 이상의 추가 성능을 뽑아낸 것은 매우 인상적인 결과입니다.

결론

이번 Triton의 변경사항은 단순히 알고리즘을 개선하는 것을 넘어, 데이터의 특성(Mask 분포)을 파악하여 하드웨어 가속기에 최적화된 형태로 재배치하는 것이 얼마나 중요한지 보여줍니다.

소프트웨어 엔지니어로서 우리는 "모든 데이터를 공평하게 처리"하려는 유혹에 빠지기 쉽지만, 고성능 컴퓨팅의 세계에서는 "비슷한 것끼리 모아서 다르게 처리"하는 것이 승리의 열쇠가 됩니다.


참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글