[SGLang] 통신 연산: AllReduce, Broadcast, AllGather 구현
들어가며
SGLang의 모델 레이어들은 분산 통신을 직접 호출하지 않는다. 대신 communication_op.py에 정의된 래퍼 함수들을 통해 간접적으로 통신한다. 이 래퍼 계층이 존재하는 이유는, Attention과 MoE가 서로 다른 병렬 그룹을 사용하기 때문이다.
구조도
Model Layer (Linear, Attention, MoE)
│
▼
communication_op.py (래퍼 함수들)
│
├── tensor_model_parallel_all_reduce() ─── get_tp_group()
├── attention_tensor_model_parallel_...() ─── get_attn_tp_group()
├── moe_tensor_model_parallel_...() ─── get_moe_tp_group()
├── moe_expert_parallel_all_reduce() ─── get_moe_ep_group()
├── tensor_model_parallel_all_gather() ─── get_tp_group()
└── broadcast_tensor_dict() ─── get_tp_group()
│
▼
GroupCoordinator.all_reduce() / all_gather() / broadcast()
│
▼
PyNccl / Custom AllReduce / MSCCL++ / torch.distributed
핵심 코드 분석
기본 AllReduce
가장 빈번하게 호출되는 함수는 tensor_model_parallel_all_reduce다. TP(Tensor Parallelism) 그룹 내에서 텐서를 합산한다.
# communication_op.py
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
단 한 줄이지만, 내부적으로 GroupCoordinator.all_reduce()가 Custom AllReduce, PyNccl, torch.distributed 중 최적의 백엔드를 선택한다.
Fused AllReduce + RMSNorm
SGLang은 AllReduce와 RMSNorm을 융합하여 커널 오버헤드를 줄이는 API도 제공한다.
def tensor_model_parallel_fused_allreduce_rmsnorm(
input_: torch.Tensor,
residual_inp_: torch.Tensor,
weight_: torch.Tensor,
eps: float,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
"""Fused TP all-reduce + RMSNorm.
Policy and backend selection are owned by GroupCoordinator:
it may dispatch to communicator-native fused APIs, custom fused kernels,
or return None so callers can run generic fallback paths.
"""
return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)
반환값이 Optional인 이유는, 해당 백엔드가 fused 연산을 지원하지 않으면 None을 반환하고 호출자가 fallback 경로를 실행하기 때문이다.
Attention 전용 AllReduce
DP Attention이 활성화되면, Attention 레이어는 전체 TP 그룹이 아닌 ATTN_TP 서브그룹에서 AllReduce를 수행한다.
def attention_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across attention parallel group."""
return get_attn_tp_group().all_reduce(input_)
예를 들어 TP=8에서 DP Attention을 2로 설정하면, ATTN_TP 그룹 크기는 4가 된다. Attention은 4-way AllReduce만 수행하므로 통신량이 절반으로 줄어든다.
MoE 전용 통신
MoE 레이어는 두 종류의 AllReduce를 사용한다:
def moe_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across moe parallel group."""
return get_moe_tp_group().all_reduce(input_)
def moe_expert_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across moe expert parallel group."""
return get_moe_ep_group().all_reduce(input_)
moe_tp_group은 Expert 내부의 텐서 분할을 합산하고, moe_ep_group은 Expert 간 결과를 합산한다.
AllGather와 Gather
AllGather는 각 랭크의 텐서 조각을 모아 전체 텐서를 복원한다.
def tensor_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)
def tensor_model_parallel_gather(
input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
all_gather는 모든 랭크가 전체 텐서를 받고, gather는 dst 랭크만 전체 텐서를 받는다. 임베딩 레이어 같이 전체 vocab을 각 랭크가 분할 보유하는 경우에 사용된다.
Broadcast
broadcast_tensor_dict는 rank 0의 텐서 딕셔너리를 다른 모든 랭크에 전파한다.
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
내부적으로 메타데이터는 CPU(Gloo)로, 텐서 데이터는 GPU(NCCL)로 전송하는 이중 채널 전략을 사용한다.
비교: 직접 호출 vs 래퍼
| 방식 | 코드 | 문제점 |
|---|---|---|
| 직접 호출 | torch.distributed.all_reduce(x, group=tp_group) |
그룹 인스턴스를 매번 전달해야 함 |
| 래퍼 사용 | tensor_model_parallel_all_reduce(x) |
그룹 자동 선택, 백엔드 자동 디스패치 |
래퍼를 사용하면 모델 코드가 병렬화 전략에 무관하게 작성된다. TP 그룹 크기가 바뀌어도 모델 레이어 코드는 변경이 필요 없다.
설계 근거
왜 Attention/MoE 전용 함수가 필요한가? Attention과 MoE가 서로 다른 병렬 그룹을 가질 수 있기 때문이다. DP Attention에서는 TP 8-way 중 Attention은 4-way, MoE는 2-way EP 같은 식으로 분할된다. 단일 all_reduce() 함수로는 이 구분이 불가능하다.
왜 fused_allreduce_rmsnorm이 Optional을 반환하는가? 모든 백엔드가 fused 커널을 지원하지 않는다. Custom AllReduce의 AMD 구현은 custom_fused_ar_rms를 제공하지만, 순수 NCCL 경로에서는 불가능하다. 호출자가 fallback을 처리하도록 None 반환이 필요하다.
관련 포스트
- SGLang Parallel State: TP/PP/DP/EP 병렬화 상태 관리
- SGLang Custom All-Reduce: NCCL 너머의 최적화된 집합 통신
- SGLang NCCL & MSCCL++: 집합 통신 라이브러리 통합
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Parallel State: TP/PP/DP/EP 병렬화 상태 관리
- 현재글 : [SGLang] 통신 연산: AllReduce, Broadcast, AllGather 구현
- 다음글 [SGLang] Custom All-Reduce: NCCL 너머의 최적화된 집합 통신
댓글