[sglang] FlashInfer TRTLLM-Gen MoE 커널 최적화: NemotronH 모델 지원 및 성능 향상
PR 링크: sgl-project/sglang#21321 상태: Merged | 변경: +None / -None
들어가며
대규모 언어 모델(LLM)의 발전과 함께 Mixture-of-Experts (MoE) 아키텍처는 모델의 파라미터 수를 크게 늘리면서도 계산 비용을 효율적으로 유지하는 핵심 기술로 부상했습니다. 특히, NVIDIA의 NemotronH-120B와 같은 최신 모델들은 MoE 구조를 활용하며, FP4 및 FP8과 같은 저정밀도 양자화를 통해 메모리 사용량과 추론 속도를 최적화합니다. 하지만 이러한 최적화된 모델을 최대한 활용하기 위해서는 기저의 텐서 연산 커널 또한 모델의 특성에 맞춰 최적화되어야 합니다.
이번 PR은 sgl-project/sglang 레포지토리에서 FlashInfer TRTLLM-Gen MoE 커널에 NemotronH-120B 모델을 위한 중요한 개선 사항을 도입합니다. 주요 목표는 NemotronH 모델에서 사용되는 non-gated (relu2) 활성화 함수를 FP4 및 FP8 MoE 커널에서 지원하고, Tensor Parallelism (TP) 환경에서 발생하는 가중치 정렬 문제를 해결하여 전반적인 성능과 정확도를 향상시키는 것입니다. 이 PR을 통해 NemotronH 모델의 추론 성능이 flashinfer_cutlass 백엔드 대비 최대 1.85배 향상되는 인상적인 결과를 보여주었습니다.
코드 분석: NemotronH MoE 커널 최적화
이 PR은 주로 python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py와 python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py 파일에 걸쳐 변경 사항을 포함합니다. 핵심은 non-gated 활성화 함수 지원과 가중치 정렬 로직 개선입니다.
1. python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py: activation_type 파라미터 추가
이 파일에서는 FlashInfer의 trtllm_fp8_block_scale_moe 및 trtllm_fp8_block_scale_routed_moe, trtllm_fp8_per_tensor_scale_moe 함수 호출에 activation_type 파라미터를 추가하여 non-gated 활성화 함수를 지원하도록 확장했습니다. 이는 NemotronH 모델이 사용하는 relu2 활성화 함수를 FlashInfer 커널이 올바르게 처리할 수 있도록 합니다.
Before:
@torch.jit.script
def _fake_fp8_block_scale_moe(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
input_scale: torch.Tensor,
activation_scale: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
expert_weights: torch.Tensor,
topk: int,
renormalize: bool,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Optional[int] = None,
) -> torch.Tensor:
return torch.empty(
hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
)
After:
@torch.jit.script
def _fake_fp8_block_scale_moe(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
input_scale: torch.Tensor,
activation_scale: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
expert_weights: torch.Tensor,
topk: int,
renormalize: bool,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Optional[int] = None,
activation_type: Optional[int] = None,
) -> torch.Tensor:
return torch.empty(
hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
)
activation_type 파라미터가 추가되었고, 이 값은 FlashInfer의 ActivationType enum으로 변환되어 커널에 전달됩니다. 이를 통해 relu2와 같은 non-gated 활성화 함수를 명시적으로 지정할 수 있게 됩니다.
if activation_type is not None:
from flashinfer.fused_moe.core import ActivationType
kwargs["activation_type"] = ActivationType(activation_type)
2. python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py: 가중치 정렬 및 스케일링 로직 개선
이 파일에서는 MoE 레이어의 가중치를 FlashInfer TRTLLM 커널에 맞게 정렬하고 패딩하는 로직이 추가 및 수정되었습니다. 특히, non-gated 활성화 함수를 사용하는 경우 intermediate 차원의 정렬 요구사항이 128로 더 엄격해집니다. 또한, is_gated 헬퍼 함수를 도입하여 활성화 함수의 종류에 따라 다른 처리 로직을 적용합니다.
_is_gated 헬퍼 함수 도입
MoE 레이어가 gated 활성화 함수를 사용하는지 여부를 판단하는 헬퍼 함수입니다. 이를 통해 이후 로직에서 gated/non-gated 모델에 따라 분기 처리가 가능해집니다.
def _is_gated(layer: Module) -> bool:
"""Return whether the MoE layer uses a gated activation (default True)."""
is_gated = (
getattr(layer, "moe_runner_config", None) and layer.moe_runner_config.is_gated
)
return True if is_gated is None else is_gated
_align_fp8_moe_weights 함수 추가
FP8 MoE 가중치에 대한 패딩 로직을 담당하는 함수입니다. is_gated 여부에 따라 min_alignment를 다르게 적용합니다. non-gated 모델의 경우 128의 정렬이 필요합니다.
Before: (해당 로직이 없었음)
After:
def _align_fp8_moe_weights(
w13: torch.Tensor,
w2: torch.Tensor,
is_gated: bool,
min_alignment: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Pad intermediate size so FlashInfer TRTLLM FP8 kernels' alignment holds.
Returns (w13, w2, padded_intermediate).
"""
num_experts, hidden_size, intermediate = w2.shape
padded_intermediate = round_up_to_multiple(intermediate, min_alignment)
if padded_intermediate == intermediate:
return w13, w2, intermediate
logger.info(
"FP8 MoE: padding intermediate size from %d to %d (alignment=%d)",
intermediate,
padded_intermediate,
min_alignment,
)
up_mult = 2 if is_gated else 1
padded_gate_up = up_mult * padded_intermediate
padded_w13 = w13.new_zeros((num_experts, padded_gate_up, w13.shape[2]))
padded_w13[:, : w13.shape[1], :] = w13
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
padded_w2[:, :, :intermediate] = w2
return padded_w13, padded_w2, padded_intermediate
align_fp8_moe_weights_for_flashinfer_trtllm 함수 수정
기존 FP8 가중치 정렬 함수에 _is_gated를 활용한 로직 분기와 _align_fp8_moe_weights 호출이 추가되었습니다. 특히, swap_w13_halves 로직이 gated 모델에만 적용되도록 변경되었고, reorder_rows_for_gated_act_gemm 또한 gated 모델에만 적용됩니다. non-gated 모델의 경우 w13_processed는 단순히 w13_weight가 됩니다.
Before:
# Optionally swap W13 halves: [Up, Gate] -> [Gate, Up]
if swap_w13_halves:
inter = two_n // 2
w13_weight = (
w13_weight.reshape(num_experts, 2, inter, hidden)
.flip(dims=[1])
.reshape(num_experts, two_n, hidden)
)
w13_interleaved_list = [
reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)
]
w13_interleaved: torch.Tensor = torch.stack(w13_interleaved_list).reshape(
num_experts, two_n, hidden
)
# Shuffle weights for transposed MMA output (both W13, W2)
epilogue_tile_m = 128
w13_shuffled = [
shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
for i in range(num_experts)
]
After:
is_gated = _is_gated(layer)
w13_weight = cast(torch.Tensor, layer.w13_weight)
w2_weight = cast(torch.Tensor, layer.w2_weight)
num_experts, gate_up_dim, hidden = w13_weight.shape
# Optionally swap W13 halves: [Up, Gate] -> [Gate, Up] (only for gated)
if swap_w13_halves and is_gated:
inter = gate_up_dim // 2
w13_weight = (
w13_weight.reshape(num_experts, 2, inter, hidden)
.flip(dims=[1])
.reshape(num_experts, gate_up_dim, hidden)
)
# Pad for kernel alignment (non-gated needs 128, gated needs 16)
min_alignment = 16 if is_gated else 128
w13_weight, w2_weight, _ = _align_fp8_moe_weights(
w13_weight, w2_weight, is_gated, min_alignment
)
num_experts, gate_up_dim, hidden = w13_weight.shape
epilogue_tile_m = 128
if is_gated:
from flashinfer import reorder_rows_for_gated_act_gemm
w13_interleaved_list = [
reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)
]
w13_processed: torch.Tensor = torch.stack(w13_interleaved_list).reshape(
num_experts, gate_up_dim, hidden
)
else:
w13_processed = w13_weight
# Shuffle weights for transposed MMA output (both W13, W2)
w13_shuffled = [
shuffle_matrix_a(w13_processed[i].view(torch.uint8), epilogue_tile_m)
for i in range(num_experts)
]
또한, 스케일링 팩터 계산 로직도 is_gated 여부에 따라 달라집니다. non-gated (Relu2)의 경우 게이트 dequantization 기여가 없으므로 output1_scales_scalar 계산이 단순화됩니다.
Before:
output1_scales_scalar = w13_weight_scale * input_scale * (1.0 / activation_scale)
After:
# For gated (SwiGLU): g1_alphas = w1_scale * a1_scale, g1_scale_c = g1_alphas / a2_scale
# For non-gated (Relu2): g1_scale_c = 1 / a2_scale (no gate dequant contribution)
if is_gated:
output1_scales_scalar = (
w13_weight_scale * input_scale * (1.0 / activation_scale)
)
else:
output1_scales_scalar = torch.ones_like(w13_weight_scale) * (
1.0 / activation_scale
)
_align_mxfp8_moe_weights 함수 추가
MXFP8 MoE 가중치 및 스케일에 대한 패딩 로직을 담당하는 함수입니다. FP8과 유사하게 is_gated 여부에 따라 min_alignment를 다르게 적용하며, 스케일 텐서 또한 함께 패딩합니다.
Before: (해당 로직이 없었음)
After:
def _align_mxfp8_moe_weights(
w13: torch.Tensor,
w13_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
is_gated: bool,
min_alignment: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
"""Pad intermediate size so FlashInfer TRTLLM MXFP8 kernels' alignment holds.
Returns (w13, w13_scale, w2, w2_scale, padded_intermediate).
"""
num_experts, hidden_size, intermediate = w2.shape
padded_intermediate = round_up_to_multiple(intermediate, min_alignment)
if padded_intermediate == intermediate:
return w13, w13_scale, w2, w2_scale, intermediate
logger.info(
"MXFP8 MoE: padding intermediate size from %d to %d (alignment=%d)",
intermediate,
padded_intermediate,
min_alignment,
)
up_mult = 2 if is_gated else 1
padded_gate_up = up_mult * padded_intermediate
padded_w13 = w13.new_zeros((num_experts, padded_gate_up, w13.shape[2]))
padded_w13[:, : w13.shape[1], :] = w13
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
padded_w2[:, :, :intermediate] = w2
padded_w13_scale = w13_scale.new_zeros(
(num_experts, padded_gate_up, w13_scale.shape[2])
)
padded_w13_scale[:, : w13_scale.shape[1], :] = w13_scale
# Scale's last dim tracks intermediate / block_size (MXFP8 block_size = 32)
scale_block_k = intermediate // w2_scale.shape[2] if w2_scale.shape[2] > 0 else 32
padded_w2_scale = w2_scale.new_zeros(
(num_experts, hidden_size, padded_intermediate // scale_block_k)
)
padded_w2_scale[:, :, : w2_scale.shape[2]] = w2_scale
return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate
리뷰어 피드백 반영
리뷰 과정에서 Fridge003님은 round_up_to_multiple 함수 사용에 대한 일관성을 지적하고, gated/non-gated 모델에 대한 가중치 정렬 로직을 분리하는 것을 제안했습니다. danielafrimi님은 이를 반영하여 _align_fp8_moe_weights와 _align_mxfp8_moe_weights와 같은 새로운 헬퍼 함수를 도입하여 코드의 가독성과 유지보수성을 높였습니다. 또한, 테스트 케이스 추가를 통해 non-gated 활성화 패딩, 활성화 타입 매핑, _is_gated 로직, non-gated 스케일링 팩터 계산 등을 검증했습니다. 이는 코드 변경의 안정성을 확보하는 데 중요한 역할을 합니다.
왜 이게 좋은 최적화/개선인가?
이 PR은 다음과 같은 이유로 좋은 최적화 및 개선 사항으로 평가할 수 있습니다.
-
NemotronH 모델 지원 및 호환성 확보: NemotronH-120B와 같은 최신 MoE 모델은
relu2와 같은 non-gated 활성화 함수를 사용합니다. 기존 FlashInfer TRTLLM-Gen 커널은 주로 gated (SwiGLU) 활성화에 최적화되어 있었기 때문에,activation_type파라미터 추가와 이에 따른 내부 로직 변경은 NemotronH 모델을sglang프레임워크에서 효율적으로 실행할 수 있도록 하는 필수적인 개선입니다. -
성능 향상: FlashInfer TRTLLM-Gen 커널은 NVIDIA TensorRT-LLM을 기반으로 하여 CUDA 커널 수준에서 높은 최적화를 제공합니다. 이 PR은 이러한 커널이 NemotronH 모델의 특성(non-gated 활성화, 특정 정렬 요구사항)을 최대한 활용할 수 있도록 가중치 정렬 및 스케일링 로직을 정교하게 조정했습니다. 결과적으로,
flashinfer_cutlass백엔드 대비 NVFP4 모델에서 최대 1.85배, FP8 모델에서 최대 1.27배의 추론 속도 향상을 달성했습니다. 이는 LLM 추론 서비스의 처리량과 지연 시간을 크게 개선합니다.- NVFP4 (NemotronH-120B-NVFP4, TP=1, single B200):
- Decode-heavy (1024→8192):
flashinfer_trtllm988 tok/s (1.85x speedup) - Prefill-heavy (8192→1024):
flashinfer_trtllm764 tok/s (1.70x speedup)
- Decode-heavy (1024→8192):
- FP8 (NemotronH-120B-FP8, TP=2, 2x B200):
- Decode-heavy (1024→8192):
flashinfer_trtllm1089 tok/s (1.27x speedup) - Prefill-heavy (8192→1024):
flashinfer_trtllm877 tok/s (1.25x speedup)
- Decode-heavy (1024→8192):
- NVFP4 (NemotronH-120B-NVFP4, TP=1, single B200):
-
정확도 유지: 성능 향상과 더불어 GSM8K 4-shot 정확도 평가에서 NemotronH-120B 모델의 정확도가 유지되거나 소폭 향상되는 것을 확인했습니다. 이는 최적화가 모델의 출력 품질에 부정적인 영향을 미치지 않음을 의미합니다.
- NemotronH-120B-A12B-NVFP4 (TP=4): 0.9310 (flexible)
- NemotronH-120B-A12B-FP8 (TP=2): 0.9234 (flexible)
-
코드의 모듈성 및 확장성:
_is_gated와 같은 헬퍼 함수와_align_fp8_moe_weights,_align_mxfp8_moe_weights와 같은 전용 패딩 함수를 도입함으로써, 코드가 gated/non-gated 모델 특성에 따라 명확하게 분리되고 모듈화되었습니다. 이는 향후 다른 MoE 모델이나 새로운 양자화 기법을 추가할 때 코드 변경의 위험을 줄이고 유지보수를 용이하게 합니다. -
저정밀도 양자화의 효율적 활용: FP4/FP8/MXFP8과 같은 저정밀도 양자화는 메모리 대역폭과 계산량을 줄여주지만, 이를 효율적으로 활용하기 위해서는 하드웨어 아키텍처(예: NVIDIA GPU의 Tensor Cores)에 맞는 데이터 정렬이 필수적입니다. 이 PR은
intermediate차원의128정렬 요구사항을 충족시키기 위한 패딩 로직을 추가하여, 저정밀도 연산이 GPU에서 최적의 성능을 발휘하도록 합니다.
결론
이 PR은 FlashInfer TRTLLM-Gen MoE 커널에 NemotronH-120B 모델을 위한 중요한 최적화와 개선을 성공적으로 통합했습니다. non-gated 활성화 함수 지원, 가중치 정렬 패딩, 그리고 스케일링 팩터 계산 로직의 정교한 조정은 NemotronH 모델의 추론 성능을 크게 향상시키면서도 정확도를 유지하는 데 기여했습니다. 이러한 최적화는 LLM 서비스의 효율성을 높이고, 더 다양한 최신 모델들을 sglang 프레임워크에서 효과적으로 활용할 수 있는 기반을 마련합니다. 이는 고성능 딥러닝 추론 시스템을 구축하는 데 있어 커널 수준의 최적화가 얼마나 중요한지를 보여주는 좋은 사례입니다.
참고 자료
- https://github.com/flashinfer/flashinfer/blob/main/python/flashinfer/fused_moe/core.py#L18-L23
- https://github.com/flashinfer/flashinfer/blob/main/python/flashinfer/fused_moe/__init__.py
- https://pytorch.org/docs/stable/generated/torch.Tensor.html
- https://pytorch.org/docs/stable/generated/torch.nn.Module.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] SGLang, FP4 KV 캐시 도입으로 LLM 추론 성능 극대화: NVFP4 최적화 분석
- 현재글 : [sglang] FlashInfer TRTLLM-Gen MoE 커널 최적화: NemotronH 모델 지원 및 성능 향상
- 다음글 [triton] Triton의 Ragged Matmul 메타데이터 계산 최적화: CPU 동기화 없는 효율적인 프로파일링
댓글