[SGLang] Batch Overlap: 연산-통신 오버랩 최적화
들어가며
MoE(Mixture of Experts) 모델의 Expert Parallelism에서 All-to-All 통신은 상당한 시간을 차지한다. SGLang의 Batch Overlap 모듈은 GPU의 연산(GEMM)과 통신(All-to-All)을 서로 다른 CUDA 스트림에서 동시에 실행하여 유휴 시간을 줄인다.
구조도
batch_overlap/
├── operations.py ── Stage 기반 실행 엔진
├── operations_strategy.py ── 연산 전략 정의
├── single_batch_overlap.py ── SBO: 단일 배치 내 오버랩
└── two_batch_overlap.py ── TBO: 두 배치 간 오버랩
[Single Batch Overlap - 타임라인]
Main Stream: │ Dispatch │ Gate+Up GEMM │ Down GEMM ──────│
Alt Stream: │ │ │ Combine (통신) │
├──────────┼──────────────┼─────────────────┤
t0 t1 t2 t3
[Two Batch Overlap - 타임라인]
Batch A: │ Stage 0 │ Stage 1 │ Stage 2 │ Stage 3 │
Batch B: │ Stage 0 │ Stage 1 │ Stage 2 │ Stage 3 │
├─────────┼─────────┼─────────┼─────────┼─────────┤
배치 A가 delta_stage만큼 선행 실행
핵심 코드 분석
Operations 실행 엔진
operations.py는 연산을 Stage 단위로 묶어 순차 실행하는 엔진이다. YieldOperation으로 스테이지 경계를 표시한다.
def execute_operations(inputs, operations):
stages = _convert_operations_to_stages(operations)
executor = _StageExecutor("primary", stages, inputs=inputs)
for _ in range(executor.num_stages):
executor.next()
return executor.output
두 배치를 오버랩할 때는 execute_overlapped_operations를 사용한다. 배치 A가 delta_stage만큼 먼저 실행되고, 이후 A와 B가 번갈아 실행된다.
def execute_overlapped_operations(inputs_arr, operations_arr, delta_stages):
executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)
for _ in range(delta_stage):
executor_a.next() # A만 선행
for _ in range(executor_a.num_stages - delta_stage):
executor_a.next() # A와 B 교차
executor_b.next()
for _ in range(delta_stage):
executor_b.next() # B 후행 완료
SBO 플래그와 오버랩 인자
single_batch_overlap.py의 SboFlags는 현재 환경에서 어떤 오버랩이 가능한지 결정한다.
class SboFlags:
@classmethod
def enable_combine_down_gemm_two_stream_overlap(cls):
return (
is_sbo_enabled()
and (get_moe_runner_backend().is_flashinfer_cutedsl()
or (get_moe_runner_backend().is_deep_gemm() and not is_blackwell()))
)
@classmethod
def enable_combine_shared_two_stream_overlap(cls):
return (is_sbo_enabled()
and not cls.enable_dispatch_shared_one_stream_overlap()
and not envs.SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO.get())
SM 분할 전략
오버랩 시 GPU의 SM을 연산용과 통신용으로 분할한다. Blackwell에서는 통신에 32개 SM, Hopper에서는 3개 SM을 기본 할당한다.
def compute_overlap_args(dispatch_output, alt_stream):
total_num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
if envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.is_set():
communicate_num_sms = envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.get()
else:
communicate_num_sms = 32 if is_blackwell() else 3
compute_num_sms = total_num_sms - communicate_num_sms
Down GEMM + Combine 오버랩
Down GEMM (전문가 출력 프로젝션)과 Combine (All-to-All 결과 수집)을 동시에 실행한다. 시그널 텐서를 사용하여 GEMM 진행 상황을 통신 측에 알린다.
if SboFlags.enable_combine_down_gemm_two_stream_overlap():
if is_blackwell():
combine_signal = torch.zeros(
num_local_experts, dtype=torch.uint32, device=device)
else:
combine_signal_size = num_local_experts * (
(num_tokens_static + MIN_BLOCK_M - 1) // MIN_BLOCK_M)
combine_signal = torch.zeros(
combine_signal_size, dtype=torch.int32, device=device)
down_gemm_overlap_args = DownGemmOverlapArgs(
signal=combine_signal, start_event=combine_wait_event,
num_sms=compute_num_sms)
combine_overlap_args.overlap = True
combine_overlap_args.signal = combine_signal
TBO: 두 배치 교차 실행
two_batch_overlap.py는 현재 배치를 반으로 나누어 두 서브배치를 파이프라인 방식으로 교차 실행한다. 한 서브배치가 통신할 때 다른 서브배치가 연산한다.
def compute_split_seq_index(forward_mode, num_tokens, extend_lens, token_num_per_seq):
if forward_mode == ForwardMode.EXTEND:
return _split_extend_seqs(extend_lens)
elif forward_mode.is_decode():
return (num_tokens // token_num_per_seq) // 2
스테이지 실행 상태 관리
_StateDict는 스테이지 간 상태를 관리하되, 같은 키에 두 번 쓰는 것을 방지하여 메모리 누수를 조기에 탐지한다.
class _StateDict:
def __setattr__(self, key, value):
assert key not in self._data, \
f"`{key}` already exist, are you sure you want to override it?"
self._data[key] = value
SBO vs TBO 비교
| 구분 | SBO (Single Batch) | TBO (Two Batch) |
|---|---|---|
| 오버랩 대상 | GEMM + 통신 | 서브배치 A + 서브배치 B |
| SM 분할 | 연산/통신 고정 분할 | 배치 간 교차 |
| 복잡도 | 낮음 | 높음 (배치 분할 필요) |
| 효과 | 통신 지연 은닉 | 파이프라인 효율 |
관련 포스트
- Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리
- Server Args: 300+ 서버 인자 완전 가이드
참고
- 소스 코드:
python/sglang/srt/batch_overlap/ - DeepEP: DeepSeek Expert Parallelism 통신 라이브러리
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Sparsity Algorithms: QUEST와 DeepSeek NSA 희소 패턴
- 현재글 : [SGLang] Batch Overlap: 연산-통신 오버랩 최적화
- 다음글 [SGLang] Model Configuration 시스템: 모델 설정 관리
댓글