본문으로 건너뛰기

[sglang] FlashInfer TRTLLM-Gen MoE 커널 최적화: NemotronH 모델 지원 및 성능 향상

PR 링크: sgl-project/sglang#21321 상태: Merged | 변경: +None / -None

들어가며

대규모 언어 모델(LLM)의 발전과 함께 Mixture-of-Experts (MoE) 아키텍처는 모델의 파라미터 수를 크게 늘리면서도 계산 비용을 효율적으로 유지하는 핵심 기술로 부상했습니다. 특히, NVIDIA의 NemotronH-120B와 같은 최신 모델들은 MoE 구조를 활용하며, FP4 및 FP8과 같은 저정밀도 양자화를 통해 메모리 사용량과 추론 속도를 최적화합니다. 하지만 이러한 최적화된 모델을 최대한 활용하기 위해서는 기저의 텐서 연산 커널 또한 모델의 특성에 맞춰 최적화되어야 합니다.

이번 PR은 sgl-project/sglang 레포지토리에서 FlashInfer TRTLLM-Gen MoE 커널에 NemotronH-120B 모델을 위한 중요한 개선 사항을 도입합니다. 주요 목표는 NemotronH 모델에서 사용되는 non-gated (relu2) 활성화 함수를 FP4 및 FP8 MoE 커널에서 지원하고, Tensor Parallelism (TP) 환경에서 발생하는 가중치 정렬 문제를 해결하여 전반적인 성능과 정확도를 향상시키는 것입니다. 이 PR을 통해 NemotronH 모델의 추론 성능이 flashinfer_cutlass 백엔드 대비 최대 1.85배 향상되는 인상적인 결과를 보여주었습니다.

코드 분석: NemotronH MoE 커널 최적화

이 PR은 주로 python/sglang/srt/layers/moe/flashinfer_trtllm_moe.pypython/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py 파일에 걸쳐 변경 사항을 포함합니다. 핵심은 non-gated 활성화 함수 지원과 가중치 정렬 로직 개선입니다.

1. python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py: activation_type 파라미터 추가

이 파일에서는 FlashInfer의 trtllm_fp8_block_scale_moetrtllm_fp8_block_scale_routed_moe, trtllm_fp8_per_tensor_scale_moe 함수 호출에 activation_type 파라미터를 추가하여 non-gated 활성화 함수를 지원하도록 확장했습니다. 이는 NemotronH 모델이 사용하는 relu2 활성화 함수를 FlashInfer 커널이 올바르게 처리할 수 있도록 합니다.

Before:

@torch.jit.script
def _fake_fp8_block_scale_moe(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    activation_scale: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    expert_weights: torch.Tensor,
    topk: int,
    renormalize: bool,
    enable_pdl: Optional[bool] = None,
    tune_max_num_tokens: int = 8192,
    fp8_quantization_type: Optional[int] = None,
) -> torch.Tensor:
    return torch.empty(
        hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
    )

After:

@torch.jit.script
def _fake_fp8_block_scale_moe(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    activation_scale: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    expert_weights: torch.Tensor,
    topk: int,
    renormalize: bool,
    enable_pdl: Optional[bool] = None,
    tune_max_num_tokens: int = 8192,
    fp8_quantization_type: Optional[int] = None,
    activation_type: Optional[int] = None,
) -> torch.Tensor:
    return torch.empty(
        hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
    )

activation_type 파라미터가 추가되었고, 이 값은 FlashInfer의 ActivationType enum으로 변환되어 커널에 전달됩니다. 이를 통해 relu2와 같은 non-gated 활성화 함수를 명시적으로 지정할 수 있게 됩니다.

    if activation_type is not None:
        from flashinfer.fused_moe.core import ActivationType

        kwargs["activation_type"] = ActivationType(activation_type)

2. python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py: 가중치 정렬 및 스케일링 로직 개선

이 파일에서는 MoE 레이어의 가중치를 FlashInfer TRTLLM 커널에 맞게 정렬하고 패딩하는 로직이 추가 및 수정되었습니다. 특히, non-gated 활성화 함수를 사용하는 경우 intermediate 차원의 정렬 요구사항이 128로 더 엄격해집니다. 또한, is_gated 헬퍼 함수를 도입하여 활성화 함수의 종류에 따라 다른 처리 로직을 적용합니다.

_is_gated 헬퍼 함수 도입

MoE 레이어가 gated 활성화 함수를 사용하는지 여부를 판단하는 헬퍼 함수입니다. 이를 통해 이후 로직에서 gated/non-gated 모델에 따라 분기 처리가 가능해집니다.

def _is_gated(layer: Module) -> bool:
    """Return whether the MoE layer uses a gated activation (default True)."""
    is_gated = (
        getattr(layer, "moe_runner_config", None) and layer.moe_runner_config.is_gated
    )
    return True if is_gated is None else is_gated

_align_fp8_moe_weights 함수 추가

FP8 MoE 가중치에 대한 패딩 로직을 담당하는 함수입니다. is_gated 여부에 따라 min_alignment를 다르게 적용합니다. non-gated 모델의 경우 128의 정렬이 필요합니다.

Before: (해당 로직이 없었음)

After:

def _align_fp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
    is_gated: bool,
    min_alignment: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pad intermediate size so FlashInfer TRTLLM FP8 kernels' alignment holds.

    Returns (w13, w2, padded_intermediate).
    """
    num_experts, hidden_size, intermediate = w2.shape

    padded_intermediate = round_up_to_multiple(intermediate, min_alignment)
    if padded_intermediate == intermediate:
        return w13, w2, intermediate

    logger.info(
        "FP8 MoE: padding intermediate size from %d to %d (alignment=%d)",
        intermediate,
        padded_intermediate,
        min_alignment,
    )

    up_mult = 2 if is_gated else 1
    padded_gate_up = up_mult * padded_intermediate

    padded_w13 = w13.new_zeros((num_experts, padded_gate_up, w13.shape[2]))
    padded_w13[:, : w13.shape[1], :] = w13

    padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
    padded_w2[:, :, :intermediate] = w2

    return padded_w13, padded_w2, padded_intermediate

align_fp8_moe_weights_for_flashinfer_trtllm 함수 수정

기존 FP8 가중치 정렬 함수에 _is_gated를 활용한 로직 분기와 _align_fp8_moe_weights 호출이 추가되었습니다. 특히, swap_w13_halves 로직이 gated 모델에만 적용되도록 변경되었고, reorder_rows_for_gated_act_gemm 또한 gated 모델에만 적용됩니다. non-gated 모델의 경우 w13_processed는 단순히 w13_weight가 됩니다.

Before:

    # Optionally swap W13 halves: [Up, Gate] -> [Gate, Up]
    if swap_w13_halves:
        inter = two_n // 2
        w13_weight = (
            w13_weight.reshape(num_experts, 2, inter, hidden)
            .flip(dims=[1])
            .reshape(num_experts, two_n, hidden)
        )

    w13_interleaved_list = [
        reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)
    ]
    w13_interleaved: torch.Tensor = torch.stack(w13_interleaved_list).reshape(
        num_experts, two_n, hidden
    )

    # Shuffle weights for transposed MMA output (both W13, W2)
    epilogue_tile_m = 128
    w13_shuffled = [
        shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
        for i in range(num_experts)
    ]

After:

    is_gated = _is_gated(layer)

    w13_weight = cast(torch.Tensor, layer.w13_weight)
    w2_weight = cast(torch.Tensor, layer.w2_weight)
    num_experts, gate_up_dim, hidden = w13_weight.shape

    # Optionally swap W13 halves: [Up, Gate] -> [Gate, Up] (only for gated)
    if swap_w13_halves and is_gated:
        inter = gate_up_dim // 2
        w13_weight = (
            w13_weight.reshape(num_experts, 2, inter, hidden)
            .flip(dims=[1])
            .reshape(num_experts, gate_up_dim, hidden)
        )

    # Pad for kernel alignment (non-gated needs 128, gated needs 16)
    min_alignment = 16 if is_gated else 128
    w13_weight, w2_weight, _ = _align_fp8_moe_weights(
        w13_weight, w2_weight, is_gated, min_alignment
    )
    num_experts, gate_up_dim, hidden = w13_weight.shape

    epilogue_tile_m = 128

    if is_gated:
        from flashinfer import reorder_rows_for_gated_act_gemm

        w13_interleaved_list = [
            reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)
        ]
        w13_processed: torch.Tensor = torch.stack(w13_interleaved_list).reshape(
            num_experts, gate_up_dim, hidden
        )
    else:
        w13_processed = w13_weight

    # Shuffle weights for transposed MMA output (both W13, W2)
    w13_shuffled = [
        shuffle_matrix_a(w13_processed[i].view(torch.uint8), epilogue_tile_m)
        for i in range(num_experts)
    ]

또한, 스케일링 팩터 계산 로직도 is_gated 여부에 따라 달라집니다. non-gated (Relu2)의 경우 게이트 dequantization 기여가 없으므로 output1_scales_scalar 계산이 단순화됩니다.

Before:

    output1_scales_scalar = w13_weight_scale * input_scale * (1.0 / activation_scale)

After:

    # For gated (SwiGLU): g1_alphas = w1_scale * a1_scale, g1_scale_c = g1_alphas / a2_scale
    # For non-gated (Relu2): g1_scale_c = 1 / a2_scale (no gate dequant contribution)
    if is_gated:
        output1_scales_scalar = (
            w13_weight_scale * input_scale * (1.0 / activation_scale)
        )
    else:
        output1_scales_scalar = torch.ones_like(w13_weight_scale) * (
            1.0 / activation_scale
        )

_align_mxfp8_moe_weights 함수 추가

MXFP8 MoE 가중치 및 스케일에 대한 패딩 로직을 담당하는 함수입니다. FP8과 유사하게 is_gated 여부에 따라 min_alignment를 다르게 적용하며, 스케일 텐서 또한 함께 패딩합니다.

Before: (해당 로직이 없었음)

After:

def _align_mxfp8_moe_weights(
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    is_gated: bool,
    min_alignment: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Pad intermediate size so FlashInfer TRTLLM MXFP8 kernels' alignment holds.

    Returns (w13, w13_scale, w2, w2_scale, padded_intermediate).
    """
    num_experts, hidden_size, intermediate = w2.shape

    padded_intermediate = round_up_to_multiple(intermediate, min_alignment)
    if padded_intermediate == intermediate:
        return w13, w13_scale, w2, w2_scale, intermediate

    logger.info(
        "MXFP8 MoE: padding intermediate size from %d to %d (alignment=%d)",
        intermediate,
        padded_intermediate,
        min_alignment,
    )

    up_mult = 2 if is_gated else 1
    padded_gate_up = up_mult * padded_intermediate

    padded_w13 = w13.new_zeros((num_experts, padded_gate_up, w13.shape[2]))
    padded_w13[:, : w13.shape[1], :] = w13

    padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
    padded_w2[:, :, :intermediate] = w2

    padded_w13_scale = w13_scale.new_zeros(
        (num_experts, padded_gate_up, w13_scale.shape[2])
    )
    padded_w13_scale[:, : w13_scale.shape[1], :] = w13_scale

    # Scale's last dim tracks intermediate / block_size (MXFP8 block_size = 32)
    scale_block_k = intermediate // w2_scale.shape[2] if w2_scale.shape[2] > 0 else 32
    padded_w2_scale = w2_scale.new_zeros(
        (num_experts, hidden_size, padded_intermediate // scale_block_k)
    )
    padded_w2_scale[:, :, : w2_scale.shape[2]] = w2_scale

    return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate

리뷰어 피드백 반영

리뷰 과정에서 Fridge003님은 round_up_to_multiple 함수 사용에 대한 일관성을 지적하고, gated/non-gated 모델에 대한 가중치 정렬 로직을 분리하는 것을 제안했습니다. danielafrimi님은 이를 반영하여 _align_fp8_moe_weights_align_mxfp8_moe_weights와 같은 새로운 헬퍼 함수를 도입하여 코드의 가독성과 유지보수성을 높였습니다. 또한, 테스트 케이스 추가를 통해 non-gated 활성화 패딩, 활성화 타입 매핑, _is_gated 로직, non-gated 스케일링 팩터 계산 등을 검증했습니다. 이는 코드 변경의 안정성을 확보하는 데 중요한 역할을 합니다.

왜 이게 좋은 최적화/개선인가?

이 PR은 다음과 같은 이유로 좋은 최적화 및 개선 사항으로 평가할 수 있습니다.

  1. NemotronH 모델 지원 및 호환성 확보: NemotronH-120B와 같은 최신 MoE 모델은 relu2와 같은 non-gated 활성화 함수를 사용합니다. 기존 FlashInfer TRTLLM-Gen 커널은 주로 gated (SwiGLU) 활성화에 최적화되어 있었기 때문에, activation_type 파라미터 추가와 이에 따른 내부 로직 변경은 NemotronH 모델을 sglang 프레임워크에서 효율적으로 실행할 수 있도록 하는 필수적인 개선입니다.

  2. 성능 향상: FlashInfer TRTLLM-Gen 커널은 NVIDIA TensorRT-LLM을 기반으로 하여 CUDA 커널 수준에서 높은 최적화를 제공합니다. 이 PR은 이러한 커널이 NemotronH 모델의 특성(non-gated 활성화, 특정 정렬 요구사항)을 최대한 활용할 수 있도록 가중치 정렬 및 스케일링 로직을 정교하게 조정했습니다. 결과적으로, flashinfer_cutlass 백엔드 대비 NVFP4 모델에서 최대 1.85배, FP8 모델에서 최대 1.27배의 추론 속도 향상을 달성했습니다. 이는 LLM 추론 서비스의 처리량과 지연 시간을 크게 개선합니다.

    • NVFP4 (NemotronH-120B-NVFP4, TP=1, single B200):
      • Decode-heavy (1024→8192): flashinfer_trtllm 988 tok/s (1.85x speedup)
      • Prefill-heavy (8192→1024): flashinfer_trtllm 764 tok/s (1.70x speedup)
    • FP8 (NemotronH-120B-FP8, TP=2, 2x B200):
      • Decode-heavy (1024→8192): flashinfer_trtllm 1089 tok/s (1.27x speedup)
      • Prefill-heavy (8192→1024): flashinfer_trtllm 877 tok/s (1.25x speedup)
  3. 정확도 유지: 성능 향상과 더불어 GSM8K 4-shot 정확도 평가에서 NemotronH-120B 모델의 정확도가 유지되거나 소폭 향상되는 것을 확인했습니다. 이는 최적화가 모델의 출력 품질에 부정적인 영향을 미치지 않음을 의미합니다.

    • NemotronH-120B-A12B-NVFP4 (TP=4): 0.9310 (flexible)
    • NemotronH-120B-A12B-FP8 (TP=2): 0.9234 (flexible)
  4. 코드의 모듈성 및 확장성: _is_gated와 같은 헬퍼 함수와 _align_fp8_moe_weights, _align_mxfp8_moe_weights와 같은 전용 패딩 함수를 도입함으로써, 코드가 gated/non-gated 모델 특성에 따라 명확하게 분리되고 모듈화되었습니다. 이는 향후 다른 MoE 모델이나 새로운 양자화 기법을 추가할 때 코드 변경의 위험을 줄이고 유지보수를 용이하게 합니다.

  5. 저정밀도 양자화의 효율적 활용: FP4/FP8/MXFP8과 같은 저정밀도 양자화는 메모리 대역폭과 계산량을 줄여주지만, 이를 효율적으로 활용하기 위해서는 하드웨어 아키텍처(예: NVIDIA GPU의 Tensor Cores)에 맞는 데이터 정렬이 필수적입니다. 이 PR은 intermediate 차원의 128 정렬 요구사항을 충족시키기 위한 패딩 로직을 추가하여, 저정밀도 연산이 GPU에서 최적의 성능을 발휘하도록 합니다.

결론

이 PR은 FlashInfer TRTLLM-Gen MoE 커널에 NemotronH-120B 모델을 위한 중요한 최적화와 개선을 성공적으로 통합했습니다. non-gated 활성화 함수 지원, 가중치 정렬 패딩, 그리고 스케일링 팩터 계산 로직의 정교한 조정은 NemotronH 모델의 추론 성능을 크게 향상시키면서도 정확도를 유지하는 데 기여했습니다. 이러한 최적화는 LLM 서비스의 효율성을 높이고, 더 다양한 최신 모델들을 sglang 프레임워크에서 효과적으로 활용할 수 있는 기반을 마련합니다. 이는 고성능 딥러닝 추론 시스템을 구축하는 데 있어 커널 수준의 최적화가 얼마나 중요한지를 보여주는 좋은 사례입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글