본문으로 건너뛰기

[triton] [Triton] Persistent Matmul 성능을 13% 향상시킨 정교한 Shared Memory 계산 기법 분석

PR 링크: triton-lang/triton#10386 상태: Merged | 변경: +13 / -10

들어가며

GPU 커널 최적화의 핵심은 '자원의 한계 내에서 얼마나 많은 작업을 병렬화(Pipelining)할 수 있는가'에 달려 있습니다. 특히 NVIDIA의 Hopper나 Blackwell 아키텍처에서 사용되는 Persistent Kernel 방식은 커널 런칭 오버헤드를 줄이고 하드웨어 활용도를 극대화하는 데 매우 효과적입니다.

하지만 최근 Triton 레포지토리에 반영된 [kernels] change heuristic of smem calculation PR에 따르면, 기존 Triton의 Shared Memory(이하 smem) 계산 로직이 다소 보수적이었음이 밝혀졌습니다. 이로 인해 충분한 메모리 공간이 있음에도 불구하고 파이프라인 스테이지(num_stages)가 제한되어 성능을 100% 끌어내지 못하고 있었습니다.

이번 글에서는 Triton이 어떻게 smem 계산 방식을 정교화하여 GB200에서 Matmul 성능을 약 13%(560 TFLOP/s -> 630 TFLOP/s) 향상시켰는지 코드 레벨에서 분석해 보겠습니다.


코드 분석: 무엇이 바뀌었는가?

1. matmul.py: 레이아웃 정보의 조기 확보

최적화의 핵심은 '현재 연산에 추가적인 smem 타일이 필요한가'를 판단하는 것입니다. 이를 위해 B 행렬(Weight)의 트랜스포즈 여부를 더 일찍 파악하도록 수정되었습니다.

Before:

# matmul 함수 하단부에서 뒤늦게 계산됨
b_transpose = b_is_shuffled or b.storage.data.stride()[-2] == 1

After:

# matmul 함수 상단부, opt_flags를 만들기 전에 미리 계산
a_transpose = a.stride(-1) != 1
b_transpose = b_is_shuffled or b.storage.data.stride()[-2] == 1

# ... 중략 ...

# opt_flags 생성 시 w_transpose 정보를 전달
make_opt_flags(
    # ...
    w_transpose = b_transpose,
    # ...
)

기존에는 b_transpose 여부를 커널 설정의 핵심인 opt_flags를 생성한 이후에 계산했습니다. 하지만 이제는 이를 미리 계산하여 compute_num_stages 로직에 주입함으로써, 실제 하드웨어 제약 사항을 더 정확히 시뮬레이션할 수 있게 되었습니다.

2. opt_flags_nvidia.py: 하드코딩된 '헤드룸' 제거와 정교한 모델링

가장 극적인 변화는 smem 가용 용량을 계산하는 휴리스틱 함수인 compute_num_stages에서 일어났습니다.

Before:

# Persistent fp32 커널에 대해 막연하게 32KB를 빼버림
if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
    smem_capacity -= 32 * 1024

# ... 중략 ...

# 무조건 최대 스테이지를 3으로 캡핑(Capping)
if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
    num_stages = min(num_stages, 3)

기존 로직은 FP32/TF32 연산 시 메타데이터나 TMA(Tensor Memory Accelerator) 상태를 위해 막연하게 32KB의 여유 공간을 남겨두고, 스테이지 수도 최대 3개로 강제 제한했습니다. 이는 안전하지만 비효율적인 방식이었습니다.

After:

# 막연한 32KB 차감 대신, 실제 필요한 '변환용 타일' 크기만큼만 차감
if rhs_dtype == FP32 and not w_transpose:
    # For fp32 B, a non-transposed input requires a transpose after its
    # TMA load before MMA. Persistent lowering materializes one extra
    # BLOCK_K x BLOCK_N tile for that conversion.
    smem_capacity -= int(block_k * block_n * weight_size)

# ... 중략 ...

# 하드코딩된 num_stages = 3 제한(min 함수) 제거
# 이제 smem_capacity가 허용하는 한 4개 이상의 스테이지도 가능해짐

개선된 로직은 "B 행렬이 FP32이고 트랜스포즈가 되어 있지 않은 경우"에만 주목합니다. 이 경우 TMA 로드 이후 MMA(Matrix Multiply-Accumulate) 연산 전 단계에서 레이아웃 변환을 위한 추가적인 BLOCK_K x BLOCK_N 크기의 타일 하나가 smem에 필요합니다.

따라서 막연한 32KB가 아니라, 실제 데이터 타입(weight_size)과 블록 크기에 기반한 정확한 바이트 수를 차감합니다. 만약 이미 트랜스포즈가 되어 있다면 이 차감조차 하지 않으므로, 더 많은 smem 공간을 파이프라인 스테이지 확장에 사용할 수 있게 됩니다.


왜 이게 좋은 최적화인가?

1. Latency Hiding의 극대화 (3-stage vs 4-stage)

GPU 연산에서 스테이지 수가 3에서 4로 늘어난다는 것은, 메모리에서 데이터를 가져오는 동안 연산을 수행하는 '더블 버퍼링' 이상의 효과를 의미합니다. 특히 Blackwell(GB200)과 같은 최신 아키텍처는 메모리 대역폭이 엄청나기 때문에, 파이프라인 깊이가 깊어질수록 연산 유닛(ALU)이 쉬지 않고 돌아갈 확률이 높아집니다. 이번 변경으로 M=N=K=4096 Matmul에서 13%의 성능 향상을 얻은 것이 그 증거입니다.

2. 하드웨어 특성에 기반한 정확한 비용 모델링

소프트웨어 엔지니어링 관점에서 "Magic Number"(여기서는 32 * 1024min(..., 3))를 제거하고, 실제 알고리즘(TMA load 후 transpose conversion)이 요구하는 메모리 비용을 수식화(block_k * block_n * weight_size)했다는 점이 훌륭합니다. 이는 코드를 더 유지보수하기 좋게 만들고, 다양한 블록 사이즈에 대해 유연하게 대응할 수 있게 합니다.

3. 조건부 최적화

모든 상황에서 스테이지를 늘리는 것이 아니라, NNN(Non-transposed) 레이아웃처럼 실제로 메모리가 더 필요한 경우에는 여전히 보수적인 계산을 유지하여 런타임 에러(Out of Resources)를 방지했습니다. 즉, 안전성을 해치지 않으면서 특정 조건(B가 올바른 레이아웃일 때)에서 성능 잠재력을 해방시킨 것입니다.

결론

이번 Triton의 변경사항은 고성능 컴퓨팅(HPC) 라이브러리에서 휴리스틱을 정교화하는 것이 얼마나 큰 임팩트를 줄 수 있는지 잘 보여줍니다. 단순히 "메모리가 부족할 것 같으니 제한하자"는 접근에서 벗어나, "정확히 어떤 연산이 얼마만큼의 메모리를 쓰는가"를 정의함으로써 최신 GPU 하드웨어의 성능을 한계까지 끌어올릴 수 있었습니다.

GB200과 같은 차세대 가속기를 다루는 엔지니어라면, 이러한 smem 관리 기법과 파이프라이닝 전략을 깊이 이해할 필요가 있습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글