[sglang] Ascend NPU에서 Qwen3 모델을 위한 W8A8 MXFP8 양자화 지원
PR 링크: sgl-project/sglang#22352 상태: Merged | 변경: +376 / -11
들어가며
최근 SGLang 프로젝트는 Ascend NPU 환경에서의 양자화 지원 범위를 넓히는 데 집중하고 있습니다. 이번 PR은 Qwen3 및 Qwen3.5와 같은 Dense LLM 모델을 Ascend NPU에서 효율적으로 구동하기 위해 W8A8 MXFP8(Microscaling FP8) 양자화를 도입했습니다. 기존에는 NPU 환경에서 FP8 양자화 지원이 제한적이었으나, 이번 변경을 통해 온라인 및 오프라인 양자화 경로를 모두 확보하여 추론 성능을 크게 향상시켰습니다.
코드 분석
1. NPUMXFP8LinearMethod 구현 (srt/hardware_backend/npu/quantization/linear_method_npu.py)
온라인 양자화를 위해 NPUMXFP8LinearMethod 클래스를 추가했습니다. 이 클래스는 로드 시점에 FP16/BF16 가중치를 MXFP8로 변환합니다.
# Before: NPU 전용 MXFP8 Linear Method 부재
# After: NPUMXFP8LinearMethod 추가
class NPUMXFP8LinearMethod(_NPULinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# ... 생략 ...
qw, w_scale = torch_npu.npu_dynamic_mx_quant(
weight_fp, dst_type=torch_npu.float8_e4m3fn
)
layer.weight = Parameter(qw.transpose(0, 1), requires_grad=False)
핵심은 npu_dynamic_mx_quant를 사용하여 가중치를 동적으로 양자화하고, npu_quant_matmul을 통해 추론 시 연산을 수행하는 것입니다. 특히 .contiguous()를 호출하지 않고 전치(transpose)된 뷰를 유지하여 메모리 대역폭을 최적화했습니다.
2. 오프라인 양자화 스킴 추가 (srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py)
msmodelslim으로 미리 양자화된 가중치를 로드하기 위한 ModelSlimMXFP8Scheme을 도입했습니다.
# Offline Path: .contiguous() 호출 금지
# calling .contiguous() would physically reorder the pre-quantized weight data
# and break the block-scale mapping, producing garbled output.
오프라인 경로에서는 이미 양자화된 가중치의 레이아웃을 보존하는 것이 중요하므로, .data 할당을 통해 비연속적인 뷰를 유지합니다.
3. 로터리 임베딩 안정성 강화 (srt/layers/rotary_embedding/base.py)
커널 부재 시 모듈 전체가 임포트되지 않는 문제를 방지하기 위해 예외 처리를 추가했습니다.
# Before
from sgl_kernel_npu import fused_rope_qk_mqa
# After
try:
from sgl_kernel_npu import fused_rope_qk_mqa
except ImportError:
fused_rope_qk_mqa = None
왜 이게 좋은가
이번 최적화는 Ascend NPU 환경에서 실질적인 성능 향상을 가져왔습니다. 벤치마크 결과에 따르면:
- 처리량(Throughput): BF16 대비 약 17% 향상.
- 지연 시간(Latency): E2E 지연 시간이 약 13% 감소.
- 효율성: MXFP8 온라인 양자화는 GSM8K 벤치마크에서 가장 높은 처리량을 기록했습니다.
이 PR의 교훈은 하드웨어 가속기(NPU)의 특성에 맞는 커널 활용과, 메모리 레이아웃 보존이 추론 성능에 결정적이라는 점입니다. 특히 .contiguous() 호출 여부가 양자화된 가중치의 블록 스케일 매핑을 깨뜨릴 수 있다는 점은 NPU 최적화 시 반드시 고려해야 할 핵심 사항입니다.
참고 자료
참고 자료
- https://gitee.com/ascend/pytorch/blob/master/docs/zh/api/npu_dynamic_mx_quant.md
- https://gitee.com/ascend/pytorch/blob/master/docs/zh/api/npu_quant_matmul.md
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] [성능 최적화] Wan2.2 모델을 위한 최적의 torch.compile 모드 찾기: 왜 'default'가 더 빠를까?
- 현재글 : [sglang] Ascend NPU에서 Qwen3 모델을 위한 W8A8 MXFP8 양자화 지원
- 다음글 [onnxruntime] ONNX Runtime WebGPU: Reduce 연산 최적화를 통한 성능 향상
댓글