[sglang] SGLang의 SM120 FP8 Blockwise GEMM 성능 최적화: Pingpong 스케줄 도입
PR 링크: sgl-project/sglang#20887 상태: Merged | 변경: +123 / -66
들어가며
최근 LLM 추론 엔진인 SGLang에서 NVIDIA SM120(Blackwell) 아키텍처를 타겟으로 하는 FP8 Blockwise GEMM 커널의 성능 최적화가 이루어졌습니다. 기존 구현체는 KernelScheduleAuto를 사용하고 있었는데, 이는 SM120 환경에서 기본적으로 'Cooperative' 커널 스케줄을 선택하게 됩니다. 하지만 Cooperative 방식은 작은 M 사이즈(행렬의 행 크기)에서 하드웨어 리소스를 충분히 활용하지 못하는 한계가 있었습니다. 본 PR은 이를 해결하기 위해 Pingpong 스케줄을 도입하여 성능을 개선했습니다.
코드 분석
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu 변경 사항
핵심 변경은 기존의 단일 커널 실행 방식에서 m <= 64 조건에 따라 두 가지 경로를 선택하도록 리팩토링한 것입니다.
1. 스케줄링 전략 분기
기존에는 KernelScheduleAuto를 통해 자동으로 스케줄이 결정되었으나, 이제는 명시적으로 KernelScheduleSm120Blockwise와 KernelTmaWarpSpecializedBlockwisePingpongSm120을 사용합니다.
// Before
cutlass::gemm::collective::KernelScheduleAuto
// After
// M > 64일 때 Cooperative 사용
cutlass::gemm::KernelScheduleSm120Blockwise
// M <= 64일 때 Pingpong 사용
cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120
2. 런타임 조건부 실행
m <= 64인 경우 Pingpong 스케줄을 우선 시도하고, 실패 시 Cooperative로 폴백(fallback)하는 로직을 추가했습니다.
if (m <= 64) {
status = run_gemm(PingpongGemmKernel{});
if (status != cutlass::Status::kSuccess) {
status = run_gemm(CooperativeGemmKernel{});
}
} else {
status = run_gemm(CooperativeGemmKernel{});
}
왜 이게 좋은가
성능 향상
RTX 5090(SM120) 환경에서 벤치마크를 수행한 결과, M=8 및 M=64와 같은 작은 사이즈에서 기존 0.063ms 대비 0.034ms로 약 2배 가까운 지연 시간(Latency) 단축을 달성했습니다. 특히 BS=1 환경에서 토큰 생성 속도가 기존 34.14 token/s에서 52.11 token/s로 비약적으로 향상되었습니다.
최적화의 교훈
- Schedule 선택의 중요성: CUTLASS와 같은 고성능 라이브러리를 사용할 때,
Auto스케줄러가 항상 최적의 성능을 보장하지는 않습니다. 특히 아키텍처 특성(SM120)에 따라Pingpong스케줄이 Cooperative보다 작은 행렬 연산에서 훨씬 유리할 수 있음을 확인했습니다. - Fallback 전략: 새로운 스케줄을 도입할 때 발생할 수 있는 잠재적 오류를 방지하기 위해, 실패 시 기존의 안정적인 경로로 돌아가는 방어적인 프로그래밍(Fallback)이 필수적입니다.
- Tile Shape 최적화:
Pingpong스케줄을 위해64x128x128의 타일 사이즈를 명시적으로 설정하여 하드웨어의 TMA(Tensor Memory Accelerator) 활용도를 극대화했습니다.
결론
이번 최적화는 SGLang이 최신 하드웨어인 SM120의 기능을 얼마나 세밀하게 제어할 수 있는지 보여주는 좋은 사례입니다. 작은 배치 사이즈에서의 성능 병목을 해결함으로써 LLM 추론의 실시간성을 크게 개선했습니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang: MiniMax-M2.5 MoE 모델을 위한 FP8 FlashInfer TRT-LLM 라우팅 최적화
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- [sglang] FlashInfer v0.6.7 MXFP8 Gemm 통합: CUTLASS와 TensorRT-LLM 백엔드 분리
- [sglang] JIT RMSNorm 커널 업데이트 - Blackwell 최적화 및 벤치마크 통합
- [sglang] fused_qknorm_rope 최적화 - interleave RoPE에서 sincosf 중복 제거
PR Analysis 의 다른글
- 이전글 [Axolotl] LoRA 커널에 bias, dropout, DoRA, embedding 지원 추가
- 현재글 : [sglang] SGLang의 SM120 FP8 Blockwise GEMM 성능 최적화: Pingpong 스케줄 도입
- 다음글 [Ultralytics] detect/obb Loss 계산의 preprocess를 벡터화하여 학습 속도 향상
댓글