[sglang] SGLang NPU 성능 최적화: INT8 TP 통신 압축 도입
PR 링크: sgl-project/sglang#20520 상태: Merged | 변경: +None / -None
들어가며
대규모 언어 모델(LLM)을 여러 디바이스에서 병렬로 추론할 때, Tensor Parallelism(TP)은 필수적입니다. 하지만 디바이스 간의 통신(All-Reduce)은 종종 병목 현상을 일으킵니다. 특히 프리필(prefill) 단계에서는 대량의 데이터를 처리해야 하므로 통신 오버헤드가 성능에 큰 영향을 미칩니다. 본 PR은 NPU 환경에서 Qwen3 모델의 TP 통신을 INT8로 압축하여 전송함으로써 통신 지연 시간을 줄이고, 전체적인 프리필 성능을 약 5% 향상시키는 최적화 기법을 도입했습니다.
코드 분석
1. NPU 통신 최적화 (python/sglang/srt/distributed/device_communicators/npu_communicator.py)
핵심은 npu_dynamic_quant를 사용하여 통신 전 데이터를 INT8로 양자화하고, 수신 후 복원하는 과정입니다.
Before:
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
dist.all_reduce(x, group=self.group)
return x
After:
def quant_all_reduce(self, x: torch.Tensor) -> torch.Tensor:
x_q, scale = npu_dynamic_quant(x, dst_type=torch.int8)
# All-gather 후 복원
dist.all_gather_into_tensor(output_tensor, x_q, group=self.group)
dist.all_gather_into_tensor(output_scale, scale, group=self.group)
output_tensor = output_tensor.to(x.dtype) * output_scale.unsqueeze(-1).to(x.dtype)
return output_tensor.sum(dim=0)
2. 레이어별 통신 제어 (python/sglang/srt/layers/linear.py)
모든 상황에서 양자화를 적용하는 대신, 프리필 단계에서만 활성화되도록 로직을 추가했습니다.
After:
quantize_communications = (
not forward_batch.forward_mode.is_decode_or_idle()
and get_global_server_args().enable_quant_communications
)
if quantize_communications:
output = tensor_model_parallel_quant_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
왜 이게 좋은가
- 통신 대역폭 절감: FP16/BF16 데이터를 INT8로 압축하면 전송해야 할 데이터 양이 절반으로 줄어들어, 대역폭 제한이 있는 NPU 환경에서 통신 지연을 효과적으로 줄입니다.
- 정확도 유지: 실험 결과(BoolQ, C-Eval, HellaSwag)에서 양자화로 인한 성능 저하가 거의 없음을 확인했습니다.
- 선택적 적용: 서버 인자
--enable-quant-communications를 통해 사용자가 필요에 따라 기능을 켜고 끌 수 있게 설계되었습니다.
교훈: 분산 시스템에서 통신은 종종 연산보다 큰 병목입니다. 데이터의 정밀도를 약간 희생하더라도 통신량을 줄이는 것이 전체 시스템 처리량(Throughput) 개선에 매우 효과적일 수 있습니다.
리뷰어 피드백 반영
리뷰 과정에서 fp_comm이라는 모호한 변수명을 quantize_communications로 변경하여 가독성을 높였으며, is_decode() 대신 is_decode_or_idle()을 사용하여 더 정확한 상태 체크를 수행하도록 개선했습니다. 또한, NPU 환경에서만 동작하도록 _is_npu 플래그를 활용한 최적화를 적용했습니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [cpython] Python JIT의 GDB 디버깅 지원: .eh_frame 생성을 통한 스택 언와인딩 구현
- 현재글 : [sglang] SGLang NPU 성능 최적화: INT8 TP 통신 압축 도입
- 다음글 [sglang] HunyuanVideo VAE 디코딩 성능 향상: GroupNorm SiLU 커널 최적화
댓글