본문으로 건너뛰기

[SGLang] Batch Overlap: 연산-통신 오버랩 최적화

들어가며

MoE(Mixture of Experts) 모델의 Expert Parallelism에서 All-to-All 통신은 상당한 시간을 차지한다. SGLang의 Batch Overlap 모듈은 GPU의 연산(GEMM)과 통신(All-to-All)을 서로 다른 CUDA 스트림에서 동시에 실행하여 유휴 시간을 줄인다.

구조도

batch_overlap/
├── operations.py           ── Stage 기반 실행 엔진
├── operations_strategy.py  ── 연산 전략 정의
├── single_batch_overlap.py ── SBO: 단일 배치 내 오버랩
└── two_batch_overlap.py    ── TBO: 두 배치 간 오버랩

[Single Batch Overlap - 타임라인]
Main Stream:  │ Dispatch │ Gate+Up GEMM │ Down GEMM ──────│
Alt Stream:   │          │              │ Combine (통신)  │
              ├──────────┼──────────────┼─────────────────┤
              t0         t1            t2                 t3

[Two Batch Overlap - 타임라인]
Batch A:  │ Stage 0 │ Stage 1 │ Stage 2 │ Stage 3 │
Batch B:            │ Stage 0 │ Stage 1 │ Stage 2 │ Stage 3 │
          ├─────────┼─────────┼─────────┼─────────┼─────────┤
          배치 A가 delta_stage만큼 선행 실행

핵심 코드 분석

Operations 실행 엔진

operations.py는 연산을 Stage 단위로 묶어 순차 실행하는 엔진이다. YieldOperation으로 스테이지 경계를 표시한다.

def execute_operations(inputs, operations):
    stages = _convert_operations_to_stages(operations)
    executor = _StageExecutor("primary", stages, inputs=inputs)
    for _ in range(executor.num_stages):
        executor.next()
    return executor.output

두 배치를 오버랩할 때는 execute_overlapped_operations를 사용한다. 배치 A가 delta_stage만큼 먼저 실행되고, 이후 A와 B가 번갈아 실행된다.

def execute_overlapped_operations(inputs_arr, operations_arr, delta_stages):
    executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
    executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)

    for _ in range(delta_stage):
        executor_a.next()        # A만 선행
    for _ in range(executor_a.num_stages - delta_stage):
        executor_a.next()        # A와 B 교차
        executor_b.next()
    for _ in range(delta_stage):
        executor_b.next()        # B 후행 완료

SBO 플래그와 오버랩 인자

single_batch_overlap.pySboFlags는 현재 환경에서 어떤 오버랩이 가능한지 결정한다.

class SboFlags:
    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        return (
            is_sbo_enabled()
            and (get_moe_runner_backend().is_flashinfer_cutedsl()
                 or (get_moe_runner_backend().is_deep_gemm() and not is_blackwell()))
        )

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        return (is_sbo_enabled()
                and not cls.enable_dispatch_shared_one_stream_overlap()
                and not envs.SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO.get())

SM 분할 전략

오버랩 시 GPU의 SM을 연산용과 통신용으로 분할한다. Blackwell에서는 통신에 32개 SM, Hopper에서는 3개 SM을 기본 할당한다.

def compute_overlap_args(dispatch_output, alt_stream):
    total_num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    if envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.is_set():
        communicate_num_sms = envs.SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS.get()
    else:
        communicate_num_sms = 32 if is_blackwell() else 3
    compute_num_sms = total_num_sms - communicate_num_sms

Down GEMM + Combine 오버랩

Down GEMM (전문가 출력 프로젝션)과 Combine (All-to-All 결과 수집)을 동시에 실행한다. 시그널 텐서를 사용하여 GEMM 진행 상황을 통신 측에 알린다.

if SboFlags.enable_combine_down_gemm_two_stream_overlap():
    if is_blackwell():
        combine_signal = torch.zeros(
            num_local_experts, dtype=torch.uint32, device=device)
    else:
        combine_signal_size = num_local_experts * (
            (num_tokens_static + MIN_BLOCK_M - 1) // MIN_BLOCK_M)
        combine_signal = torch.zeros(
            combine_signal_size, dtype=torch.int32, device=device)

    down_gemm_overlap_args = DownGemmOverlapArgs(
        signal=combine_signal, start_event=combine_wait_event,
        num_sms=compute_num_sms)
    combine_overlap_args.overlap = True
    combine_overlap_args.signal = combine_signal

TBO: 두 배치 교차 실행

two_batch_overlap.py는 현재 배치를 반으로 나누어 두 서브배치를 파이프라인 방식으로 교차 실행한다. 한 서브배치가 통신할 때 다른 서브배치가 연산한다.

def compute_split_seq_index(forward_mode, num_tokens, extend_lens, token_num_per_seq):
    if forward_mode == ForwardMode.EXTEND:
        return _split_extend_seqs(extend_lens)
    elif forward_mode.is_decode():
        return (num_tokens // token_num_per_seq) // 2

스테이지 실행 상태 관리

_StateDict는 스테이지 간 상태를 관리하되, 같은 키에 두 번 쓰는 것을 방지하여 메모리 누수를 조기에 탐지한다.

class _StateDict:
    def __setattr__(self, key, value):
        assert key not in self._data, \
            f"`{key}` already exist, are you sure you want to override it?"
        self._data[key] = value

SBO vs TBO 비교

구분 SBO (Single Batch) TBO (Two Batch)
오버랩 대상 GEMM + 통신 서브배치 A + 서브배치 B
SM 분할 연산/통신 고정 분할 배치 간 교차
복잡도 낮음 높음 (배치 분할 필요)
효과 통신 지연 은닉 파이프라인 효율

관련 포스트

  • Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리
  • Server Args: 300+ 서버 인자 완전 가이드

참고

댓글

관련 포스트

SGLang 의 다른글