본문으로 건너뛰기

[triton] Multi-CTA 예제에서 Program ID를 Shared Memory에 저장하여 재계산 방지

PR 링크: triton-lang/triton#9656 상태: Merged | 변경: +71 / -10

들어가며

Warp-specialized 커널에서 CLC(Cluster Launch Control) 기반 타일 스케줄링은 동적으로 작업 타일을 할당합니다. 이때 _planar_snake 함수로 tile_id를 (pid_m, pid_n)으로 변환하는 연산이 필요한데, 기존에는 consumer 파티션에서 매번 이 계산을 수행했습니다. 이 PR은 CLC 파티션에서 계산한 pid를 shared memory에 저장하여 consumer가 직접 읽도록 합니다.

핵심 코드 분석

Before (consumer에서 매번 계산):

@gluon.jit
def get_offsets(self):
    pid_m, pid_n = _planar_snake(self.tile_id, self.num_pid_m,
                                  self.num_pid_n, self.MINOR_DIM,
                                  self.GRID_TILE_WIDTH)
    return pid_m * self.TILE_M, pid_n * self.TILE_N

After (shared memory에서 직접 읽기):

# CLC 파티션에서 계산 후 shared memory에 저장
if has_work:
    tile_id = clc_res.program_id(0)
    pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, ...)
packed_pid = (pid_m.to(gl.int64) << 32) | (pid_n.to(gl.int64) & 0xFFFFFFFF)
planar_slot.store(gl.full([1], packed_pid, gl.int64, layout=planar_layout))
mbarrier.arrive(planar_ready_bar)

# Consumer에서는 shared memory에서 읽기만 함
mbarrier.wait(self.clc_planar_ready_bars.index(counter.index), counter.phase)
packed_pid = planar_slot.load(planar_layout).reshape([])
pid_m = ((packed_pid >> 32) & 0xFFFFFFFF).to(gl.int32)
pid_n = (packed_pid & 0xFFFFFFFF).to(gl.int32)

왜 이게 좋은가

CLC 파티션은 이미 clc_res.program_id(0)에서 tile_id를 알고 있으므로, 여기서 계산한 pid를 재활용하는 것이 합리적입니다. _planar_snake는 나눗셈과 모듈로 연산을 포함하므로 GPU에서 비싼 연산입니다. 64비트 정수에 두 32비트 pid를 pack하여 single write/read로 전달하는 것은 shared memory 대역폭을 효율적으로 사용하며, mbarrier를 통해 동기화를 보장합니다.

정리

CLC 파티션에서 planar snake pid를 계산한 후 shared memory에 packed 형태로 저장하고, consumer 파티션에서는 barrier 동기화 후 읽기만 수행하도록 변경하여 중복 계산을 제거했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글