[sglang] MoE 모델 추론 최적화: Triton 커널 퓨전을 통한 TTFT 28% 개선
PR 링크: sgl-project/sglang#19672 상태: Merged | 변경: +67 / -9
들어가며
대규모 언어 모델(LLM)의 발전과 함께 Mixture-of-Experts(MoE) 모델은 그 효율성과 성능으로 주목받고 있습니다. MoE 모델은 여러 개의 '전문가(expert)' 네트워크를 두고 입력에 따라 특정 전문가를 선택하여 연산 부하를 줄이면서도 높은 성능을 유지합니다. 하지만 이러한 복잡한 구조는 추론 과정에서 여러 커널(kernel) 호출과 데이터 이동을 야기하여 성능 병목 현상을 초래할 수 있습니다. 특히, fused_moe_triton 커널과 moe_sum_all_reduce 커널은 MoE 모델의 핵심 연산으로, 이들 간의 불필요한 데이터 전송과 개별 커널 실행 오버헤드는 First Token Latency (TTFT)에 큰 영향을 미칩니다.
이번 PR은 이러한 문제를 해결하기 위해 fused_moe_triton과 moe_sum_all_reduce 커널을 하나로 퓨전(fusion)하여 MoE 모델 추론의 TTFT를 20%~30% 가량 개선하는 것을 목표로 합니다. 이 최적화는 중복되는 중간 계산과 데이터 처리를 제거하고, GPU 글로벌 메모리 접근을 줄여 메모리 대역폭 병목 현상을 완화하며, 커널 실행 및 스케줄링 오버헤드를 낮춰 GPU 자원 효율성을 향상시킵니다.
코드 분석
이 PR의 핵심 변경사항은 fused_moe_triton 커널 내에서 moe_sum_all_reduce의 기능을 통합하여, 두 연산을 하나의 GPU 커널에서 처리하도록 만든 것입니다. 이는 주로 python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py와 python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py 파일에서 이루어졌습니다.
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
이 파일에서는 커널 퓨전 활성화 여부를 결정하는 로직과, 퓨전된 커널을 호출할 때 인자를 전달하는 부분이 변경되었습니다.
1. 퓨전 활성화 조건 추가
fused_experts_impl 함수 내에서 use_fused_moe_sum_all_reduce 변수를 추가하여, 서버 설정(enable_fused_moe_sum_all_reduce 플래그)과 특정 조건(no_combine이 아니며, topk_ids.shape[1] > 2, int8/int4 양자화 미사용)이 만족될 때만 퓨전 커널을 사용하도록 했습니다.
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -466,6 +468,14 @@ def fused_experts_impl(
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+ use_fused_moe_sum_all_reduce = (
+ get_global_server_args().enable_fused_moe_sum_all_reduce
+ and (not no_combine)
+ and (curr_topk_ids.shape[1] > 2)
+ and (not use_int8_w8a16)
+ and (not use_int4_w4a16)
+ )
+
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], E
)
2. 출력 버퍼 처리 및 커널 호출 변경
퓨전 커널을 사용할 경우, out_hidden_states의 해당 슬라이스를 미리 0으로 초기화하고, 이 버퍼를 invoke_fused_moe_kernel 함수에 직접 전달하도록 변경했습니다. 이는 intermediate_cache3와 같은 중간 버퍼를 사용하지 않고 최종 결과 버퍼에 직접 누적 연산을 수행함으로써 메모리 접근을 최적화합니다.
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -569,14 +579,23 @@ def fused_experts_impl(
else:
raise ValueError(f"Unsupported activation: {activation=}, with {is_gated=}")
+ out_slice = None
+ if use_fused_moe_sum_all_reduce:
+ out_slice = out_hidden_states[begin_chunk_idx:end_chunk_idx]
+ out_slice.zero_()
+
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
b2,
(
- intermediate_cache3
- if not no_combine and topk_ids.shape[1] != 1
- else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
+ out_slice
+ if use_fused_moe_sum_all_reduce
+ else (
+ intermediate_cache3
+ if not no_combine and topk_ids.shape[1] != 1
+ else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
+ )
),
a2_scale,
w2_scale,
@@ -599,6 +618,8 @@ def fused_experts_impl(
a_use_tma=down_moe_use_tma,
b_use_tma=down_moe_use_tma,
filter_expert=filter_expert,
+ fuse_sum_all_reduce=use_fused_moe_sum_all_reduce,
+ router_topk=curr_topk_ids.shape[1],
)
if routed_scaling_factor is None:
@@ -607,7 +628,13 @@ def fused_experts_impl(
if no_combine:
pass
elif _is_cuda:
- if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
+ if use_fused_moe_sum_all_reduce:
+ if routed_scaling_factor is None:
+ routed_scaling_factor = 1.0
+ if routed_scaling_factor != 1.0:
+ assert out_slice is not None
+ out_slice.mul_(routed_scaling_factor)
+ elif topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
pass # we write directly into out_hidden_states
elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
torch.add(
퓨전이 활성화되면, 기존에 intermediate_cache3에 저장된 후 torch.add나 moe_sum_reduce를 통해 out_hidden_states로 합쳐지던 과정이 생략됩니다. 대신, invoke_fused_moe_kernel 내부에서 직접 out_hidden_states에 결과를 누적하게 됩니다.
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
이 파일은 실제 Triton 커널 구현을 포함하며, 퓨전 로직의 핵심이 여기에 있습니다.
1. 커널 인자 추가
fused_moe_kernel 함수에 FUSE_SUM_ALL_REDUCE와 ROUTER_TOPK라는 두 개의 tl.constexpr 인자가 추가되었습니다. 이들은 커널 내부에서 퓨전 로직을 조건부로 활성화하고, 라우터의 topk 값을 활용하는 데 사용됩니다.
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
@@ -377,6 +377,8 @@ def fused_moe_kernel(
c_sorted: tl.constexpr,
filter_expert: tl.constexpr,
swap_ab: tl.constexpr,
+ FUSE_SUM_ALL_REDUCE: tl.constexpr,
+ ROUTER_TOPK: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
2. 출력 쓰기 로직 변경 (핵심)
가장 중요한 변경은 accumulator (전문가 연산 결과)를 최종 출력 c_ptr에 쓰는 방식입니다. FUSE_SUM_ALL_REDUCE가 True일 경우, tl.atomic_add를 사용하여 여러 전문가의 결과를 c_ptr의 동일한 위치에 원자적으로 누적합니다. 이는 moe_sum_reduce 커널이 별도로 호출되어 수행하던 sum_all_reduce 연산을 fused_moe_kernel 내부에서 직접 처리하는 것을 의미합니다.
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
@@ -601,14 +603,27 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- if c_sorted:
+ if FUSE_SUM_ALL_REDUCE:
+ offs_token_out = offs_token // ROUTER_TOPK
c_ptrs = (
- c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
+ c_ptr + stride_cm * offs_token_out[:, None] + stride_cn * offs_cn[None, :]
)
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+ tl.atomic_add(c_ptrs, accumulator, mask=c_mask)
else:
- c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
- c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
- tl.store(c_ptrs, accumulator, mask=c_mask)
+ if c_sorted:
+ c_ptrs = (
+ c_ptr
+ + stride_cm * offs_token_id[:, None]
+ + stride_cn * offs_cn[None, :]
+ )
+ else:
+ c_ptrs = (
+ c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+ )
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+ tl.store(c_ptrs, accumulator, mask=c_mask)
# -----------------------------------------------------------------------------
offs_token_out = offs_token // ROUTER_TOPK 계산은 각 토큰이 여러 전문가에게 라우팅될 때, 최종 출력 버퍼에서는 하나의 토큰 위치에 모든 전문가의 결과가 합쳐져야 함을 반영합니다. tl.atomic_add는 여러 스레드가 동시에 동일한 메모리 위치에 쓸 때 발생하는 경쟁 조건(race condition)을 방지하며 올바른 합산 결과를 보장합니다.
3. invoke_fused_moe_kernel 인자 전달 및 제약 조건 추가
invoke_fused_moe_kernel 함수도 fuse_sum_all_reduce와 router_topk 인자를 받도록 업데이트되었습니다. 또한, fuse_sum_all_reduce가 활성화된 경우 c_sorted가 False여야 하고, GPTQ/AWQ와 같은 양자화 커널에서는 지원되지 않는다는 제약 조건이 추가되었습니다.
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
@@ -700,6 +715,8 @@ def invoke_fused_moe_kernel(
b_use_tma: bool = False,
c_sorted: bool = False,
filter_expert: bool = True,
+ fuse_sum_all_reduce: bool = False,
+ router_topk: int = 1,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -767,11 +784,17 @@ def invoke_fused_moe_kernel(
else:
even_Ks = False
+ if fuse_sum_all_reduce:
+ assert not c_sorted, "fuse_sum_all_reduce only supports c_sorted=False"
+
if (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
):
+ assert (
+ not fuse_sum_all_reduce
+ ), "fuse_sum_all_reduce is not supported for GPTQ/AWQ kernels"
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
assert bias is None
@@ -878,6 +901,8 @@ def invoke_fused_moe_kernel(
c_sorted=c_sorted,
filter_expert=filter_expert,
swap_ab=swap_ab,
+ FUSE_SUM_ALL_REDUCE=fuse_sum_all_reduce,
+ ROUTER_TOPK=router_topk,
**config,
)
python/sglang/srt/server_args.py
이 파일에서는 새로운 커널 퓨전 기능을 활성화하기 위한 CLI 인자 --enable-fused-moe-sum-all-reduce가 추가되었습니다. 이를 통해 사용자가 서버 실행 시 해당 최적화를 선택적으로 적용할 수 있습니다.
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -644,6 +644,7 @@ class ServerArgs:
nsa_prefill_cp_mode: str = "round-robin-split"
enable_fused_qk_norm_rope: bool = False
enable_precise_embedding_interpolation: bool = False
+ enable_fused_moe_sum_all_reduce: bool = False
# Dynamic batch tokenizer
enable_dynamic_batch_tokenizer: bool = False
@@ -4892,6 +4893,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable corner alignment for resize of embeddings grid to ensure more accurate(but slower) evaluation of interpolated embedding values.",
)
+ parser.add_argument(
+ "--enable-fused-moe-sum-all-reduce",
+ action="store_true",
+ help="Enable fused moe triton and sum all reduce.",
+ )
# Dynamic batch tokenizer
parser.add_argument(
왜 이게 좋은가
이 커널 퓨전 최적화는 MoE 모델 추론의 여러 측면에서 상당한 이점을 제공합니다.
-
TTFT (Time To First Token) 28.3% 개선: PR 설명에 따르면, 이 최적화를 통해 TTFT가 254ms에서 182ms로 약 28.3% 감소했습니다. 이는 사용자 경험에 직접적인 영향을 미치는 중요한 지표로, 특히 대화형 AI 서비스에서 체감 성능을 크게 향상시킵니다.
-
Main branch 벤치마크 결과:
-
PR 적용 후 벤치마크 결과:
-
-
메모리 대역폭 병목 완화: 기존에는
fused_moe_triton커널이 중간 결과를 GPU 글로벌 메모리에 쓰고,moe_sum_all_reduce커널이 이를 다시 읽어와 합산하는 과정이 있었습니다. 커널 퓨전을 통해 이 중간 쓰기/읽기 작업이 사라지면서 GPU 글로벌 메모리 접근이 줄어들고, 이는 메모리 대역폭 병목 현상을 크게 완화합니다. -
커널 실행 및 스케줄링 오버헤드 감소: 두 개의 개별 커널을 하나로 합치면서, 커널 런치(launch) 및 스케줄링에 필요한 CPU 오버헤드가 줄어듭니다. GPU는 커널 실행 준비 및 전환에 시간이 소요되므로, 커널 수를 줄이는 것은 전반적인 효율성을 높이는 데 기여합니다.
-
데이터 지역성(Data Locality) 향상: 중간 결과를 레지스터나 공유 메모리(shared memory)와 같은 온칩(on-chip) 메모리에 더 오래 유지할 수 있게 되어, 데이터 지역성이 향상됩니다. 이는 GPU의 메모리 계층 구조를 효율적으로 활용하여 데이터 접근 지연 시간을 줄입니다.
일반적 교훈
이 최적화는 GPU 기반 고성능 컴퓨팅에서 중요한 몇 가지 교훈을 제공합니다.
- 커널 퓨전의 중요성: 여러 개의 작은 커널을 하나의 큰 커널로 퓨전하는 것은 GPU 연산에서 매우 효과적인 최적화 기법입니다. 이는 불필요한 글로벌 메모리 접근과 커널 런치 오버헤드를 줄여줍니다.
- 메모리 접근 최적화: GPU 성능의 핵심은 메모리 대역폭입니다. 중간 결과를 글로벌 메모리에 불필요하게 쓰고 읽는 것을 피하고, 가능한 한 레지스터나 공유 메모리를 활용하는 것이 중요합니다.
- Triton의 활용: Triton과 같은 DSL(Domain-Specific Language)은 GPU 커널을 직접 작성하고 최적화하는 데 강력한 도구입니다. 이를 통해 PyTorch와 같은 고수준 프레임워크에서 달성하기 어려운 세밀한 성능 튜닝이 가능합니다.
- 조건부 최적화: 모든 상황에 하나의 최적화가 적용될 수 있는 것은 아닙니다. 이 PR에서 보듯이, 특정 조건(예:
topk_ids.shape[1] > 2, 특정 양자화 미사용)에서만 퓨전 커널을 활성화하여 범용성과 성능을 동시에 고려하는 접근 방식이 중요합니다.
이러한 최적화는 MoE 모델과 같이 복잡하고 연산 집약적인 모델의 실시간 추론 성능을 향상시키는 데 필수적이며, LLM 서비스의 사용자 경험을 크게 개선할 수 있습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.add.html
- https://pytorch.org/docs/stable/generated/torch.mul.html
- https://pytorch.org/docs/stable/generated/torch.Tensor.zero_.html
- https://pytorch.org/docs/stable/generated/torch.Tensor.unsqueeze.html
- https://triton-lang.org/main/language/core-language/atomic.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] AMD BlockPingpong 패스의 non-MFMA dot 크래시 수정
- 현재글 : [sglang] MoE 모델 추론 최적화: Triton 커널 퓨전을 통한 TTFT 28% 개선
- 다음글 [sglang] SGLang, Helios 모델 통합으로 실시간 장편 비디오 생성의 새로운 지평을 열다
댓글