본문으로 건너뛰기

[SGLang] Parallel State: TP/PP/DP/EP 병렬화 상태 관리

들어가며

대규모 LLM 추론에서 하나의 GPU로는 모델 전체를 올릴 수 없다. SGLang은 Tensor Parallelism(TP), Pipeline Parallelism(PP), Data Parallelism(DP), Expert Parallelism(EP) 네 가지 병렬화를 동시에 지원하며, 이 모든 것의 기반이 되는 모듈이 parallel_state.py다.

이 파일은 vLLM의 동명 모듈을 기반으로 하되, SGLang 고유의 Attention DP, Context Parallelism, MoE DP 등을 추가로 관리한다. 핵심은 GroupCoordinator 클래스와 전역 그룹 변수들이다.

구조도

                    ┌──────────────────────────────────┐
                    │   initialize_model_parallel()     │
                    └──────────┬───────────────────────┘
                               │
            ┌──────────────────┼──────────────────────┐
            │                  │                      │
            ▼                  ▼                      ▼
      ┌──────────┐      ┌──────────┐           ┌──────────┐
      │  _TP     │      │  _PP     │           │  _MOE_EP │
      │ (tp)     │      │ (pp)     │           │  _MOE_TP │
      └────┬─────┘      └──────────┘           │  _MOE_DP │
           │                                   └──────────┘
     ┌─────┼──────┐
     ▼            ▼
┌─────────┐  ┌─────────┐
│ _ATTN_TP│  │ _ATTN_CP│
└─────────┘  └─────────┘

각 그룹 = GroupCoordinator 인스턴스
  ├── device_group (NCCL/HCCL/XCCL)
  ├── cpu_group    (Gloo)
  ├── pynccl_comm  (PyNccl)
  ├── ca_comm      (Custom AllReduce)
  └── mq_broadcaster (Shared Memory)

핵심 코드 분석

전역 그룹 변수

SGLang은 병렬화 그룹을 전역 싱글턴으로 관리한다.

# parallel_state.py
_TP: Optional[GroupCoordinator] = None
_ATTN_TP: Optional[GroupCoordinator] = None
_ATTN_CP: Optional[GroupCoordinator] = None
_PP: Optional[GroupCoordinator] = None
_MOE_DP: Optional[GroupCoordinator] = None
_MOE_EP: Optional[GroupCoordinator] = None
_MOE_TP: Optional[GroupCoordinator] = None

각 변수에 대응하는 getter 함수가 존재한다. 예를 들어 get_tp_group()_TP를 반환하되, PD-Multiplexing이 활성화된 경우 별도의 _PDMUX_PREFILL_TP_GROUP을 반환한다.

GroupCoordinator 초기화

GroupCoordinator는 PyTorch의 ProcessGroup을 감싸면서, 다양한 커뮤니케이터를 조건부로 생성한다.

class GroupCoordinator:
    def __init__(self, group_ranks, local_rank, torch_distributed_backend,
                 use_pynccl, use_pymscclpp, use_custom_allreduce,
                 use_torch_symm_mem_all_reduce, use_hpu_communicator,
                 use_xpu_communicator, use_npu_communicator,
                 use_message_queue_broadcaster=False, group_name=None, ...):
        # 1. device_group과 cpu_group 생성
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            cpu_group = torch.distributed.new_group(
                ranks, backend="gloo")

        # 2. 커뮤니케이터 조건부 생성
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(...)
        if use_custom_allreduce and self.world_size > 1:
            CAClass = dispatch_custom_allreduce()
            self.ca_comm = CAClass(...)

모든 그룹은 device_group(GPU 통신용)과 cpu_group(Gloo, CPU 조율용)을 쌍으로 보유한다. 이 이중 구조는 메타데이터 교환과 GPU 텐서 통신을 분리하기 위함이다.

병렬화 계층 구조

initialize_model_parallel() 함수는 4종 병렬화의 그룹을 계층적으로 생성한다. 8 GPU 예시를 보면:

def initialize_model_parallel(
    tensor_model_parallel_size=1,
    expert_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    attention_data_parallel_size=1,
    attention_context_model_parallel_size=1,
    moe_data_model_parallel_size=1, ...):

    # TP 그룹: [g0,g1,g2,g3], [g4,g5,g6,g7]
    # PP 그룹: [g0,g4], [g1,g5], [g2,g6], [g3,g7]
    # ATTN_CP 그룹: cp_size 기반 서브그룹
    # MOE_EP 그룹: ep_size 기반 서브그룹

Attention과 MoE는 독립적인 병렬화 계층을 갖는다:

Attention: Global(TP) -> DP -> ATTN_CP -> ATTN_TP (innermost)
MoE:       Global(TP) -> MOE_DP -> EP -> MOE_TP (innermost)

AllReduce 디스패치

GroupCoordinator의 all_reduce() 메서드는 상황에 따라 최적의 통신 백엔드를 선택한다.

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    if self.world_size == 1:
        return input_

    # 1. 하드웨어별 커뮤니케이터 확인
    if self.hpu_communicator is not None and not self.hpu_communicator.disabled:
        return self.hpu_communicator.all_reduce(input_)

    # 2. Symmetric Memory (DP Attention용)
    if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():
        with self.pynccl_comm.change_state(enable=True):
            self.pynccl_comm.all_reduce(input_)
            return input_

    # 3. Custom AllReduce → Quick AllReduce → MSCCL++ → SymmMem → PyNccl → torch.distributed
    outplace_all_reduce_method = None
    if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
        outplace_all_reduce_method = "ca"
    elif self.qr_comm is not None and self.qr_comm.should_quick_allreduce(input_):
        outplace_all_reduce_method = "qr"
    # ...

이 디스패치 로직은 텐서 크기, CUDA 그래프 모드, 하드웨어 종류에 따라 6단계 우선순위로 백엔드를 선택한다.

CUDA Graph 모드에서의 통신 전환

Eager 모드와 Graph 모드에서 활성화되는 통신 방식이 다르다:

# graph_capture 컨텍스트 내부
# allreduce \ Mode   |  Eager  |  Graph  |
# custom allreduce   | enabled | enabled |
# PyNccl             | disabled| enabled |
# PyMscclpp          | disabled| enabled |
# torch.distributed  | enabled | disabled|

설계 근거

왜 전역 싱글턴인가? 모델의 모든 레이어가 동일한 병렬화 그룹을 공유해야 하므로, 인스턴스를 전달하는 것보다 전역 접근이 실용적이다. _register_group()weakref로 생명주기를 관리한다.

왜 device_group과 cpu_group을 분리하는가? NCCL의 barrier는 내부적으로 GPU 텐서를 생성하여 디바이스 혼란을 일으킬 수 있다. CPU 조율은 Gloo를 쓰고, GPU 텐서 통신만 NCCL을 쓰는 구조가 안전하다.

왜 Attention과 MoE의 병렬화가 분리되어 있는가? Attention은 KV-cache 때문에 DP/CP가 유리하고, MoE는 Expert 분산이 핵심이다. 두 레이어의 통신 패턴이 근본적으로 다르기 때문에 별도 그룹이 필요하다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글