본문으로 건너뛰기

[SGLang] 통신 연산: AllReduce, Broadcast, AllGather 구현

들어가며

SGLang의 모델 레이어들은 분산 통신을 직접 호출하지 않는다. 대신 communication_op.py에 정의된 래퍼 함수들을 통해 간접적으로 통신한다. 이 래퍼 계층이 존재하는 이유는, Attention과 MoE가 서로 다른 병렬 그룹을 사용하기 때문이다.

구조도

Model Layer (Linear, Attention, MoE)
    │
    ▼
communication_op.py (래퍼 함수들)
    │
    ├── tensor_model_parallel_all_reduce()     ─── get_tp_group()
    ├── attention_tensor_model_parallel_...()   ─── get_attn_tp_group()
    ├── moe_tensor_model_parallel_...()        ─── get_moe_tp_group()
    ├── moe_expert_parallel_all_reduce()       ─── get_moe_ep_group()
    ├── tensor_model_parallel_all_gather()      ─── get_tp_group()
    └── broadcast_tensor_dict()                ─── get_tp_group()
    │
    ▼
GroupCoordinator.all_reduce() / all_gather() / broadcast()
    │
    ▼
PyNccl / Custom AllReduce / MSCCL++ / torch.distributed

핵심 코드 분석

기본 AllReduce

가장 빈번하게 호출되는 함수는 tensor_model_parallel_all_reduce다. TP(Tensor Parallelism) 그룹 내에서 텐서를 합산한다.

# communication_op.py
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_)

단 한 줄이지만, 내부적으로 GroupCoordinator.all_reduce()가 Custom AllReduce, PyNccl, torch.distributed 중 최적의 백엔드를 선택한다.

Fused AllReduce + RMSNorm

SGLang은 AllReduce와 RMSNorm을 융합하여 커널 오버헤드를 줄이는 API도 제공한다.

def tensor_model_parallel_fused_allreduce_rmsnorm(
    input_: torch.Tensor,
    residual_inp_: torch.Tensor,
    weight_: torch.Tensor,
    eps: float,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Fused TP all-reduce + RMSNorm.

    Policy and backend selection are owned by GroupCoordinator:
    it may dispatch to communicator-native fused APIs, custom fused kernels,
    or return None so callers can run generic fallback paths.
    """
    return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)

반환값이 Optional인 이유는, 해당 백엔드가 fused 연산을 지원하지 않으면 None을 반환하고 호출자가 fallback 경로를 실행하기 때문이다.

Attention 전용 AllReduce

DP Attention이 활성화되면, Attention 레이어는 전체 TP 그룹이 아닌 ATTN_TP 서브그룹에서 AllReduce를 수행한다.

def attention_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across attention parallel group."""
    return get_attn_tp_group().all_reduce(input_)

예를 들어 TP=8에서 DP Attention을 2로 설정하면, ATTN_TP 그룹 크기는 4가 된다. Attention은 4-way AllReduce만 수행하므로 통신량이 절반으로 줄어든다.

MoE 전용 통신

MoE 레이어는 두 종류의 AllReduce를 사용한다:

def moe_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across moe parallel group."""
    return get_moe_tp_group().all_reduce(input_)

def moe_expert_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across moe expert parallel group."""
    return get_moe_ep_group().all_reduce(input_)

moe_tp_group은 Expert 내부의 텐서 분할을 합산하고, moe_ep_group은 Expert 간 결과를 합산한다.

AllGather와 Gather

AllGather는 각 랭크의 텐서 조각을 모아 전체 텐서를 복원한다.

def tensor_model_parallel_all_gather(
    input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_tp_group().all_gather(input_, dim)

def tensor_model_parallel_gather(
    input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group."""
    return get_tp_group().gather(input_, dst, dim)

all_gather는 모든 랭크가 전체 텐서를 받고, gatherdst 랭크만 전체 텐서를 받는다. 임베딩 레이어 같이 전체 vocab을 각 랭크가 분할 보유하는 경우에 사용된다.

Broadcast

broadcast_tensor_dict는 rank 0의 텐서 딕셔너리를 다른 모든 랭크에 전파한다.

def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    if not torch.distributed.is_initialized():
        return tensor_dict
    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)

내부적으로 메타데이터는 CPU(Gloo)로, 텐서 데이터는 GPU(NCCL)로 전송하는 이중 채널 전략을 사용한다.

비교: 직접 호출 vs 래퍼

방식 코드 문제점
직접 호출 torch.distributed.all_reduce(x, group=tp_group) 그룹 인스턴스를 매번 전달해야 함
래퍼 사용 tensor_model_parallel_all_reduce(x) 그룹 자동 선택, 백엔드 자동 디스패치

래퍼를 사용하면 모델 코드가 병렬화 전략에 무관하게 작성된다. TP 그룹 크기가 바뀌어도 모델 레이어 코드는 변경이 필요 없다.

설계 근거

왜 Attention/MoE 전용 함수가 필요한가? Attention과 MoE가 서로 다른 병렬 그룹을 가질 수 있기 때문이다. DP Attention에서는 TP 8-way 중 Attention은 4-way, MoE는 2-way EP 같은 식으로 분할된다. 단일 all_reduce() 함수로는 이 구분이 불가능하다.

왜 fused_allreduce_rmsnorm이 Optional을 반환하는가? 모든 백엔드가 fused 커널을 지원하지 않는다. Custom AllReduce의 AMD 구현은 custom_fused_ar_rms를 제공하지만, 순수 NCCL 경로에서는 불가능하다. 호출자가 fallback을 처리하도록 None 반환이 필요하다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글