본문으로 건너뛰기

[sglang] SGLang: MiniMax-M2.5 MoE 모델을 위한 FP8 FlashInfer TRT-LLM 라우팅 최적화

PR 링크: sgl-project/sglang#20394 상태: Merged | 변경: +None / -None

들어가며

대규모 언어 모델(LLM)의 추론 성능은 서비스의 응답성과 비용 효율성에 직결됩니다. 특히 Mixture-of-Experts (MoE) 모델은 파라미터 수가 방대하여 높은 성능 최적화가 필수적입니다. 이번 PR은 sgl-project/sglang 레포지토리에서 MiniMax-M2.5 MoE 모델의 FP8 추론을 위해 flashinfer_trtllm_routed 백엔드를 활성화하고, 관련 최적화를 적용하여 성능을 향상시키는 것을 목표로 합니다.

기존 MoE 추론은 Triton 커널을 사용했지만, FlashInfer TRT-LLM 백엔드를 활용하여 FP8 양자화된 MoE 모델의 추론 속도를 개선하고, 특히 라우팅된(routed) MoE 모델에 대한 지원을 강화하는 것이 핵심입니다. 이를 통해 GB200 환경에서 최대 9.04%의 속도 향상을 달성했습니다.

코드 분석: 무엇이 왜 좋은 최적화/개선인가

이번 PR은 주로 flashinfer_trtllm MoE 러너의 동작을 개선하고, FP8 양자화된 가중치 처리 로직을 확장하며, FlashInfer의 특정 버그를 우회하는 데 초점을 맞추고 있습니다.

1. python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py

이 파일에서는 fused_experts_none_to_flashinfer_trtllm_fp8 함수 내에서 MoE 커널의 출력 데이터 타입을 처리하는 방식이 변경되었습니다. 기존에는 출력을 torch.bfloat16으로 강제했지만, 이제는 hidden_states의 원래 데이터 타입을 따르도록 수정되었습니다.

Before:

            symm_output = torch.empty(
                hidden_states.shape[0],
                hidden_states.shape[1],
                dtype=torch.bfloat16,
                device=hidden_states.device,
            )

After:

            symm_output = torch.empty(
                hidden_states.shape[0],
                hidden_states.shape[1],
                dtype=hidden_states.dtype,
                device=hidden_states.device,
            )

무엇이 왜 좋은가: 이 변경은 MoE 커널의 출력 데이터 타입을 hidden_statesdtype과 일치시킴으로써 불필요한 타입 캐스팅을 줄이고 데이터 일관성을 유지합니다. 특히 FP8과 같은 저정밀도 타입을 사용하는 경우, 중간 결과가 bfloat16으로 강제 변환되었다가 다시 원래 타입으로 돌아오는 과정에서 발생할 수 있는 오버헤드나 정밀도 손실을 방지할 수 있습니다. 이는 성능과 정확도 면에서 모두 긍정적인 영향을 미칩니다.

또한, FlashInfer 버그로 인해 symm_output에 결과를 복사하는 임시 방편이 추가되었습니다. 이는 flashinfer-ai/flashinfer/issues/2703 이슈가 해결되면 제거될 예정입니다.

        # TODO: Once https://github.com/flashinfer-ai/flashinfer/issues/2703 is fixed, pass output to moe kernel and remove this copy.
        symm_output.copy_(output)
        output = symm_output

TODO 주석은 현재 FlashInfer 라이브러리의 제약사항을 명시하고 있으며, 향후 라이브러리 업데이트 시 성능 개선의 여지가 있음을 보여줍니다. 현재는 커널이 직접 출력 텐서를 받지 못하므로 복사본을 생성하여 처리하는 임시 방편을 사용하고 있습니다.

2. python/sglang/srt/layers/quantization/fp8.py

이 파일에서는 FP8 가중치를 FlashInfer의 per-tensor 커널 레이아웃에 맞추는 로직이 flashinfer_trtllm_routed 백엔드에도 적용되도록 확장되었습니다.

Before:

            if get_moe_runner_backend().is_flashinfer_trtllm():

After:

            if (
                get_moe_runner_backend().is_flashinfer_trtllm()
                or get_moe_runner_backend().is_flashinfer_trtllm_routed()
            ):

무엇이 왜 좋은가: flashinfer_trtllm_routed 백엔드도 flashinfer_trtllm과 동일하게 FlashInfer의 특정 가중치 레이아웃을 필요로 합니다. 이 변경은 routed 버전의 MoE 백엔드를 사용할 때도 FP8 가중치가 FlashInfer 커널에 올바르게 전달될 수 있도록 보장합니다. zianglih 리뷰어의 코멘트처럼 두 백엔드 모두 동일한 swizzling(가중치 재배열)이 필요하므로, 이 확장은 필수적입니다. 이는 flashinfer_trtllm_routed를 통한 MoE 모델의 정확하고 효율적인 FP8 추론을 가능하게 합니다.

또한, getattr 함수의 기본값 처리 방식이 개선되었습니다. routing_method_typeNone으로 설정된 경우에도 기본값이 올바르게 적용되도록 수정되었습니다.

Before:

                routing_method_type=int(
                    getattr(layer, 

## 참고 자료
- https://pytorch.org/docs/stable/generated/torch.empty.html
- https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html
- https://pytorch.org/docs/stable/generated/torch.bfloat16.html

> ⚠️ **알림:** 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글