[sglang] DeepEP Low Latency FP8 Dispatch 변경 revert
PR 링크: sgl-project/sglang#21719 상태: Merged | 변경: +12 / -94
들어가며
DeepSeek-R1-0528-w4a8 모델을 위해 DeepEP Low Latency dispatch에 FP8 통신을 도입하는 변경이 있었다. 이 변경에는 per-token FP8 양자화를 per-tensor로 변환하는 Triton 커널과 관련 로직이 포함되었다. 그러나 이 변경이 문제를 야기하여 revert되었다.
핵심 코드 분석
1. FP8 per-token to per-tensor 변환 커널 제거
Before (revert 대상):
def cutlass_w4a8_moe_deepep_ll(
a_states: torch.Tensor, # FP8 텐서
a_scales: torch.Tensor, # per-token scale
...
):
gateup_input = torch.empty(a_states.shape, dtype=torch.float8_e4m3fn)
fp8_per_token_to_per_tensor_quant_triton(
x=a_states, x_scale=a_scales,
masked_m=masked_m, output_scale=a1_scale,
output=gateup_input,
)
After (revert 후):
def cutlass_w4a8_moe_deepep_ll(
a: torch.Tensor, # BF16 텐서
...
):
gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn)
per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)
FP8 dispatch를 위해 추가된 fp8_per_token_to_per_tensor_quant_triton Triton 커널(약 76줄)이 완전히 제거되고, BF16 입력에 대한 기존 per_tensor_quant_fp8 경로로 복원되었다.
2. BF16 dispatch 강제 assertion 이동
Before:
# forward_cutlass_w4afp8 (normal dispatch)
assert envs.SGLANG_DEEPEP_BF16_DISPATCH.get(), \
"W4AFP8 does not support FP8 normal dispatch"
After:
# forward_cutlass_w4afp8_masked (low-latency dispatch)에만 assertion
assert envs.SGLANG_DEEPEP_BF16_DISPATCH.get(), \
"W4AFP8 does not support FP8 dispatch"
BF16 dispatch 강제 assertion이 normal dispatch에서 제거되고 low-latency(masked) dispatch에만 적용된다.
3. FP8 dispatch 비활성화
Before:
# deepep.py dispatch
else:
use_fp8 = True # 항상 FP8
After:
elif not envs.SGLANG_DEEPEP_BF16_DISPATCH.get():
use_fp8 = True # BF16 모드가 아닐 때만 FP8
SGLANG_DEEPEP_BF16_DISPATCH가 설정되어 있으면 FP8을 사용하지 않도록 조건을 추가했다.
왜 이게 좋은가
- 안정성 복원: 문제가 있던 FP8 dispatch 경로를 안전하게 비활성화
- 코드 정리: 사용되지 않는 Triton 커널 76줄 제거로 코드베이스 단순화
- 정확한 scope: normal dispatch는 FP8을 허용하되, low-latency dispatch만 BF16을 강제
정리
성능 최적화를 위한 FP8 통신 도입이 예상대로 동작하지 않아 revert된 케이스다. per-token to per-tensor FP8 변환 Triton 커널과 관련 인터페이스 변경이 깔끔하게 롤백되었고, dispatch 모드별 assertion이 올바른 위치로 조정되었다.
참고 자료
- sgl-project/sglang#21719 — 원본 PR
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] CI에서 NVIDIA wheel 로컬 캐싱으로 830MB 반복 다운로드 방지
- 현재글 : [sglang] DeepEP Low Latency FP8 Dispatch 변경 revert
- 다음글 [Triton] AMD gfx1250 Tensor Descriptor 기반 GEMM 테스트 추가
댓글