[triton] Expert Parallelism 기본 구현과 Reduce 커널 추가
PR 링크: triton-lang/triton#8448 상태: Merged | 변경: +1700 / -153
들어가며
Mixture-of-Experts(MoE) 모델에서 Expert Parallelism은 각 expert를 서로 다른 GPU에 배치하여 병렬 처리하는 기법입니다. 이 PR은 Triton Kernels에 Expert Parallelism의 기본 구현을 추가하고, MoE matmul에서 사용되는 reduce 연산을 독립 커널로 분리합니다.
핵심 코드 분석
새로운 reduce 커널 핵심 루프
@triton.jit
def _reduce(X, stride_xr, stride_x0, stride_x1,
XMx, stride_xmxr, stride_xmx0, stride_xmx1,
Y, stride_y0, stride_y1, ...):
pid_s0 = tl.program_id(0)
pid_s1 = tl.program_id(1)
y = tl.zeros((BLOCK_S0, BLOCK_S1), dtype=tl.float32)
for k in tl.range(0, K, num_stages=2):
x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_s1[None, :] * stride_x1
x = tl.load(x_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=0.0)
# MX scale 적용
if XMx is not None:
xmx = (xmx.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
x = (xmx[:, :, None] * x.reshape(...)).reshape(...)
y += x
Before - matmul_ogs에 내장된 reduce
# matmul_ogs.py 내부에 _reduce_grouped이 결합되어 있음
from .matmul_ogs_details._reduce_grouped import _reduce_grouped
After - 독립적인 reduce 모듈
# 독립 모듈로 분리
from .reduce import reduce
from .reduce import PostprocessFn as ReducePostprocessFn
왜 이게 좋은가
- 관심사 분리: reduce 연산을 matmul에서 독립시켜 다른 컨텍스트에서도 재사용 가능하게 만들었습니다.
- 유연한 후처리:
PostprocessFn을 통해 reduce 결과에 커스텀 후처리(SwiGLU 등)를 적용할 수 있습니다. - MoE 확장성: Expert Parallelism의 기반을 마련하여 split-k reduction과 inter-expert reduction을 분리할 수 있게 했습니다.
정리
이 PR은 Triton Kernels에 Expert Parallelism을 위한 첫 번째 발판을 마련하면서, reduce 커널을 독립 모듈로 분리하여 코드의 재사용성과 유지보수성을 높인 의미 있는 변경입니다. MoE 모델의 분산 학습/추론 파이프라인에 직접적으로 기여하는 인프라 작업입니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] AMD ds_read_tr 명령어 제한 완화로 더 유연한 레이아웃 지원
- 현재글 : [triton] Expert Parallelism 기본 구현과 Reduce 커널 추가
- 다음글 [ultralytics] Ultralytics 8.3.215: 세그멘테이션 마스크 처리 성능 최적화 분석
댓글