본문으로 건너뛰기

[SGLang] Pipeline Parallelism 스케줄러: PP 믹스인 설계

들어가며

대형 모델을 여러 GPU에 나눠 실행할 때, Pipeline Parallelism(PP)은 모델을 스테이지 단위로 분할하고 마이크로배치를 파이프라인처럼 흘려보내는 방식이다. SGLang은 이 PP 스케줄링 로직을 SchedulerPPMixin 클래스로 분리해, 기존 스케줄러에 믹스인 패턴으로 결합한다. 이 글에서는 python/sglang/srt/managers/scheduler_pp_mixin.py의 핵심 설계를 분석한다.

구조도

┌──────────────────────────────────────────────────────┐
│                  event_loop_pp()                      │
│  ┌─────────┐  ┌─────────┐       ┌─────────┐         │
│  │  MB #0  │→ │  MB #1  │→ ... →│ MB #N-1 │→ (반복) │
│  └────┬────┘  └────┬────┘       └────┬────┘         │
│       │            │                 │               │
│  ┌────▼────────────▼─────────────────▼────┐         │
│  │         per micro-batch step           │         │
│  │  1. recv_requests (from prev stage)    │         │
│  │  2. send_requests (to next stage)      │         │
│  │  3. get_next_batch_to_run              │         │
│  │  4. recv proxy tensors                 │         │
│  │  5. launch batch (GPU forward)         │         │
│  │  6. send/recv output tensors           │         │
│  │  7. process_batch_result               │         │
│  └────────────────────────────────────────┘         │
└──────────────────────────────────────────────────────┘

Stage 0 ──proxy──▶ Stage 1 ──proxy──▶ Stage 2
         ◀─output──          ◀─output──

핵심 코드 분석

1. PP 루프 상태 초기화

PP 루프는 pp_size + pp_async_batch_depth 개의 마이크로배치 슬롯을 운용한다. async batch depth를 추가하면 마지막 스테이지에서 GPU 연산과 CPU 후처리를 오버랩할 수 있다.

def init_pp_loop_state(self: Scheduler):
    self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth
    self.mbs = [None] * self.pp_loop_size
    self.last_mbs = [None] * self.pp_loop_size
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False)
        for _ in range(self.pp_loop_size)
    ]
    self.mb_metadata: List[Optional[PPBatchMetadata]] = [None] * self.pp_loop_size
    self.pp_outputs: Optional[PPProxyTensors] = None
    self.last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque()

last_rank_comm_queue는 마지막 스테이지에서 출력 텐서를 버퍼링하는 큐로, async batch depth 만큼의 출력을 지연 전송할 수 있게 한다.

2. 메인 이벤트 루프

event_loop_pp는 각 마이크로배치를 순회하며 동일한 스케줄을 반복한다. 핵심은 비동기 send + 동기 recv 조합이다.

def event_loop_pp(self: Scheduler):
    self.init_pp_loop_state()
    while True:
        server_is_idle = True
        for mb_id in range(self.pp_loop_size):
            self.running_batch = self.running_mbs[mb_id]
            next_first_rank_mb_id = (mb_id + self.pp_size) % self.pp_loop_size
            next_mb_id = (mb_id + 1) % self.pp_loop_size

            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            if not self.pp_group.is_last_rank:
                self._pp_commit_comm_work(self.send_req_work)
                self.send_req_work = self._pp_send_pyobj_to_next_stage(
                    recv_reqs, async_send=True,
                )

요청을 수신하자마자 비동기로 다음 스테이지에 전달한다. _pp_commit_comm_work은 이전 비동기 send의 완료를 보장한다.

3. 스테이지 간 텐서 통신: 타입 기반 디멀티플렉싱

PP 통신에서 proxy 텐서(hidden states)와 output 텐서(next_token_ids)는 같은 채널을 쓰지만, 메시지 타입으로 구분한다.

def _pp_send_dict_to_next_stage(
    self: Scheduler, tensor_dict, async_send=True, msg_type="default",
):
    tensor_dict["__msg_type__"] = msg_type
    p2p_work = []
    p2p_work.extend(
        self.pp_group.send_tensor_dict(
            tensor_dict=tensor_dict,
            all_gather_group=(
                self.attn_tp_group if self.require_attn_tp_allgather else None
            ),
            async_send=async_send,
        )
    )
    return p2p_work

수신 측에서는 _pp_recv_typed_dict가 기대하는 타입이 아닌 메시지를 inbox 큐에 스태시(stash)해 두고, 올바른 타입이 올 때까지 반복 수신한다.

def _pp_recv_typed_dict(self: Scheduler, expected_kind="default", ...):
    if expected_kind in self._pp_tensor_dict_inbox:
        inbox_queue = self._pp_tensor_dict_inbox[expected_kind]
        if inbox_queue:
            return inbox_queue.popleft()
    while True:
        tensor_dict = self.pp_group.recv_tensor_dict(...)
        received_kind = tensor_dict.get("__msg_type__", "default")
        if received_kind == expected_kind:
            return tensor_dict
        else:
            self._pp_tensor_dict_inbox[received_kind].append(tensor_dict)

4. 배치 실행과 출력 버퍼링

마지막 스테이지는 forward 결과를 즉시 전송하지 않고, CUDA 이벤트와 함께 큐에 넣어 둔다. 이 버퍼링이 async batch depth의 핵심이다.

def _pp_launch_batch(self: Scheduler, mb_id, pp_proxy_tensors, ...):
    with self.forward_stream_ctx:
        self.forward_stream.wait_stream(self.schedule_stream)
        result = self.run_batch(self.cur_batch, pp_proxy_tensors)
        event = torch.cuda.Event()
        event.record(torch.cuda.current_stream())
        if self.pp_group.is_last_rank:
            last_rank_comm_queue.append((
                event,
                PPProxyTensors(self._pp_prepare_tensor_dict(result, self.cur_batch)),
            ))
    return result, event

5. 동적 청크 사이징: ChunkSizePredictor

PP 환경에서 각 스테이지의 연산 시간을 균등하게 맞추기 위해, SGLang은 프리필 레이턴시를 프로파일링하고 이차(quadratic) 모델 f(l) = al^2 + bl + c로 피팅한다.

class ChunkSizePredictor:
    def predict_next_chunk_size(self, history_len, base_chunk_size, ...):
        # f(L+x) - f(L) = target_latency를 풀어 x를 구한다
        # ax^2 + (2aL+b)x - T = 0 → 이차방정식의 양의 근
        A = self.quadratic_coeff_a
        B = 2 * self.quadratic_coeff_a * history_len + self.linear_coeff_b
        C = -self.target_latency
        discriminant = B * B - 4 * A * C
        sqrt_discriminant = math.sqrt(discriminant)
        calculated_chunk_size_float = (-B + sqrt_discriminant) / (2 * A)

시퀀스가 길어질수록(history_len 증가) 어텐션의 O(n^2) 특성 때문에 청크 사이즈를 줄여야 하며, 이 예측기가 자동으로 조절한다.

왜 이 설계인가

믹스인 패턴의 이유: PP 로직은 TP/DP와 독립적으로 동작하는 직교(orthogonal) 관심사다. 믹스인으로 분리하면 PP 없는 환경에서는 이 코드가 전혀 로드되지 않고, PP+DP 조합도 자연스럽게 구성된다.

비동기 send + 동기 recv: 비동기 send는 GPU 연산과 통신을 오버랩해 버블을 줄이고, 동기 recv는 스테이지 간 순서가 어긋나지 않도록 보장한다. 완전 비동기로 가면 디버깅이 극도로 어려워지므로, 이 하이브리드 접근이 실용적이다.

async batch depth: 마지막 스테이지가 출력을 버퍼에 쌓아두면, GPU가 다음 배치를 처리하는 동안 CPU가 이전 결과를 후처리할 수 있다. 이는 마지막 랭크의 straggler 현상을 완화한다.

동적 청크 사이징: 고정 청크 사이즈는 짧은 시퀀스에서 GPU를 과소 활용하고, 긴 시퀀스에서 버블을 유발한다. 이차 모델 피팅으로 시퀀스 길이에 따라 적응적으로 청크 크기를 조절하면 파이프라인 효율이 극대화된다.

관련 포스트

  • [SGLang] DP Attention 믹스인 분석 (별도 포스트)
  • [SGLang] Prefill Delayer: 전략적 프리필 지연 분석 (별도 포스트)

참고

댓글

관련 포스트

SGLang 의 다른글