본문으로 건너뛰기

[SGLang] Data Parallel Attention 스케줄러: DP Attention 믹스인

들어가며

Data Parallel Attention(DP Attention)은 어텐션 연산을 여러 DP 워커에 분산하되, MLP는 Tensor Parallel로 처리하는 하이브리드 병렬화 전략이다. 각 DP 워커가 서로 다른 배치를 처리할 수 있기 때문에, 워커 간 토큰 수와 배치 속성을 동기화하는 메커니즘이 필수적이다. SGLang은 이 로직을 SchedulerDPAttnMixinMLPSyncBatchInfo로 구현하며, 파일은 python/sglang/srt/managers/scheduler_dp_attn_mixin.py에 위치한다.

구조도

DP Worker 0          DP Worker 1          DP Worker 2
┌──────────┐        ┌──────────┐        ┌──────────┐
│ Attn(독립)│        │ Attn(독립)│        │ Attn(독립)│
│  tokens=5 │        │ tokens=12│        │  tokens=0│
└────┬─────┘        └────┬─────┘        └────┬─────┘
     │                   │                   │
     └──────── all_gather (MLPSyncBatchInfo) ┘
                         │
                ┌────────▼────────┐
                │  MLP (TP 동기)   │
                │ global_tokens=  │
                │   [5, 12, 0]    │
                └─────────────────┘

핵심 코드 분석

1. MLPSyncBatchInfo: DP 워커 간 상태 동기화

각 DP 워커는 자신의 로컬 배치 정보를 6개 정수 텐서로 인코딩한다. all_gather를 통해 모든 워커의 정보를 수집한 뒤, TP rank 0의 정보만 추출해 글로벌 상태를 결정한다.

@dataclass
class MLPSyncBatchInfo:
    dp_size: int
    tp_size: int
    cp_size: int
    num_tokens: int
    num_tokens_for_logprob: int
    can_cuda_graph: bool
    is_extend_in_batch: bool
    local_can_run_tbo: bool
    local_forward_mode: int

    def _get_local_tensor(self, device, dtype=torch.int64) -> torch.Tensor:
        return torch.tensor(
            [
                self.num_tokens,
                self.num_tokens_for_logprob,
                int(self.can_cuda_graph),
                int(self.is_extend_in_batch),
                int(self.local_can_run_tbo),
                self.local_forward_mode,
            ],
            device=device, dtype=dtype,
        )

6개 필드를 하나의 텐서에 패킹하여 단 한 번의 all_gather로 모든 정보를 교환한다.

2. All-Gather와 글로벌 상태 결정

all_gather 메서드는 (dp_size, tp_size * cp_size, 6) 형상의 텐서에 모든 워커의 정보를 수집한다. 비활성 TP 랭크에는 fallback 값을 설정하여 글로벌 판단에 영향을 주지 않도록 한다.

def all_gather(self, device, group):
    local_info_tensor = self._get_local_tensor(device=device)
    global_info_tensor = torch.empty(
        (self.dp_size, self.tp_size * self.cp_size, 6),
        dtype=torch.int64, device=device,
    )
    torch.distributed.all_gather_into_tensor(
        global_info_tensor.flatten(), local_info_tensor, group=group,
    )
    # 비활성 랭크에 fallback 설정
    tp_info = global_info_tensor.view(self.dp_size * self.tp_size * self.cp_size, 6)
    tp_info[tp_active_ranks == 0] = self._get_fallback_tensor(device=device)

    tp0_info = global_info_tensor[:, 0, :]
    self.tp0_info = tp0_info
    # D2H 복사 한 번으로 CPU 데이터 획득
    cpu_data = tp0_info[:, :2].cpu()
    self.global_num_tokens = cpu_data[:, 0].tolist()
    self.global_num_tokens_for_logprob = cpu_data[:, 1].tolist()
    self.can_cuda_graph = bool(tp0_info[:, 2].min().item())
    self.is_extend_in_batch = bool(tp0_info[:, 3].max().item())

주목할 점은 can_cuda_graphmin(모든 워커가 가능해야), is_extend_in_batchmax(하나라도 extend가 있으면)를 사용하는 것이다. 이 보수적/낙관적 집계 전략이 DP 환경의 정합성을 보장한다.

3. Fallback 텐서: 비활성 워커 처리

비활성 TP 랭크가 글로벌 판단을 왜곡하지 않도록, fallback 값은 신중하게 설정된다.

def _get_fallback_tensor(self, device, dtype=torch.int64) -> torch.Tensor:
    return torch.tensor(
        [
            0,  # num_tokens - 토큰 없음
            0,  # num_tokens_for_logprob
            1,  # can_cuda_graph - True (min에서 영향 없음)
            0,  # is_extend_in_batch - False (max에서 영향 없음)
            1,  # local_can_run_tbo - True
            ForwardMode.IDLE.value,  # IDLE 모드
        ],
        device=device, dtype=dtype,
    )

4. 토큰 수 계산: Extend vs Decode

prepare_mlp_sync_batch_raw 함수는 배치의 forward mode에 따라 토큰 수를 다르게 계산한다.

def prepare_mlp_sync_batch_raw(local_batch, dp_size, ...):
    if local_batch is None or local_batch.forward_mode.is_prebuilt():
        num_tokens = 0
        num_tokens_for_logprob = 0
    elif local_batch.forward_mode.is_decode():
        num_tokens = local_batch.batch_size()
        num_tokens_for_logprob = num_tokens
    else:  # extend (prefill)
        num_tokens = local_batch.extend_num_tokens
        num_tokens_for_logprob = sum(
            max(extend_len - logprob_start_len, 1)
            for logprob_start_len, extend_len in zip(
                local_batch.extend_logprob_start_lens,
                local_batch.extend_lens,
            )
        )

Decode 모드에서는 batch_size()가 곧 토큰 수이고, extend 모드에서는 logprob 시작 위치를 고려해 실제 샘플링 대상 토큰만 카운트한다.

5. Idle 배치 생성

로컬 워커에 실행할 배치가 없어도, 다른 DP 워커가 배치를 실행 중이면 MLP 동기화를 위해 idle 배치를 생성해야 한다.

need_idle_batch = skip_all_gather or max(mlp_sync_info.global_num_tokens) > 0
if need_idle_batch:
    batch_to_gather = local_batch
    if local_batch is None:
        batch_to_gather = local_batch = get_idle_batch()
    elif local_batch.forward_mode.is_prebuilt():
        batch_to_gather = local_batch.inner_idle_batch = get_idle_batch()

6. 믹스인 클래스: 스케줄러와의 결합

SchedulerDPAttnMixin은 스케줄러의 설정값들을 raw 함수에 전달하는 얇은 래퍼다.

class SchedulerDPAttnMixin:
    def prepare_mlp_sync_batch(self: Scheduler, local_batch):
        return prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=self.server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
            attn_cp_size=self.attn_cp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=self.server_args.disable_cuda_graph,
            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
            offload_tags=self.offload_tags,
        )

왜 이 설계인가

어텐션과 MLP의 분리: DP Attention에서 어텐션은 각 워커가 독립 실행하지만, MLP는 TP로 동기 실행해야 한다. 이 비대칭 구조 때문에 MLP 진입 전 글로벌 배치 정보 동기화가 필수적이다. MLPSyncBatchInfo가 바로 이 경계에서 정보를 교환하는 역할을 한다.

단일 all_gather로 최소 통신: 6개 정수를 하나의 텐서에 패킹하고 단 한 번의 all_gather_into_tensor로 교환한다. D2H 메모리 복사도 한 번만 수행해 통신 오버헤드를 최소화한다.

보수적/낙관적 집계: CUDA graph 사용 가능 여부는 min(보수적)으로, extend 존재 여부는 max(낙관적)로 판단한다. 이는 DP 워커 간 불일치 상황에서도 안전하고 정확한 실행을 보장한다.

CPU vs GPU 통신 경로: overlap schedule이 비활성화되었거나 NCCL all_gather 모드일 때는 GPU 디바이스 그룹을, 그렇지 않으면 CPU 그룹을 사용한다. 이 분기는 통신과 연산의 오버랩 가능성에 따라 최적 경로를 선택한다.

관련 포스트

  • [SGLang] Pipeline Parallelism 스케줄러: PP 믹스인 설계 (별도 포스트)
  • [SGLang] Prefill Delayer: 전략적 프리필 지연 분석 (별도 포스트)

참고

댓글

관련 포스트

SGLang 의 다른글