[vLLM] Context Parallelism: 컨텍스트 병렬화
들어가며
긴 컨텍스트(수만~수십만 토큰)를 처리할 때 단일 GPU의 KV 캐시 메모리로는 부족하다. Context Parallelism은 KV 캐시를 여러 GPU에 분산하여 이 문제를 해결한다. vLLM의 Decode Context Parallelism(DCP)은 디코드 단계에서 All-to-All(A2A) 통신을 사용하여 기존 AllGather + ReduceScatter 방식보다 통신 횟수를 줄인다.
소스 경로: vllm/v1/attention/ops/dcp_alltoall.py
논문: Context Parallelism for Long-Context LLM Inference
공식 문서
vLLM 공식 문서: Context Parallel Deployment
핵심 구조/코드 분석
핵심 아이디어
기존 방식(AG+RS)은 어텐션 레이어당 3번의 NCCL 호출이 필요하다:
- Q 텐서 AllGather
- K 메타데이터 AllGather
- 출력 ReduceScatter
A2A 방식은 이를 2번으로 줄인다:
- 부분 어텐션 출력 All-to-All
- LSE(Log-Sum-Exp) 값 All-to-All
LSE 가중 결합 (CPU 참조 구현)
def _lse_weighted_combine(
outputs: torch.Tensor, # [N, B, H, D]
lses: torch.Tensor, # [N, B, H]
return_lse: bool = False,
is_lse_base_on_e: bool = True,
) -> torch.Tensor:
N, B, H, D = outputs.shape
# NaN/Inf 처리
lses = torch.where(
torch.isnan(lses) | torch.isinf(lses),
torch.tensor(float("-inf"), device=lses.device, dtype=lses.dtype),
lses,
)
# 수치 안정성을 위한 max LSE 계산
lse_max, _ = lses.max(dim=0)
# Softmax 가중치 계산
weights = torch.exp(lses - lse_max.unsqueeze(0))
weight_sum = weights.sum(dim=0, keepdim=True)
weights = weights / weight_sum.clamp(min=1e-10)
# 가중 결합
result = (outputs * weights.unsqueeze(-1)).sum(dim=0)
return result
각 GPU가 자신의 KV 샤드에 대해 계산한 부분 어텐션 출력과 LSE를 교환한 후, softmax 가중 결합으로 최종 출력을 생성한다. 수치 안정성을 위해 max LSE를 빼서 오버플로를 방지한다.
Triton 커널
@triton.jit
def _dcp_lse_combine_kernel(
recv_output_ptr, recv_lse_ptr, out_ptr, out_lse_ptr,
# strides...
N: tl.constexpr, HEAD_DIM: tl.constexpr,
IS_BASE_E: tl.constexpr, RETURN_LSE: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
head_idx = tl.program_id(1).to(tl.int64)
# 1단계: max LSE 계산 (수치 안정성)
lse_max = -float("inf")
for n in tl.static_range(N):
lse_val = tl.load(recv_lse_ptr + n * rl_stride_N + base_lse_offset)
lse_max = tl.maximum(lse_max, lse_val)
# 2단계: exp(lse - max) 합산
lse_sum = 0.0
for n in tl.static_range(N):
lse_val = tl.load(recv_lse_ptr + ...)
if IS_BASE_E:
lse_sum += tl.exp(lse_val - lse_max)
else:
lse_sum += tl.exp2(lse_val - lse_max)
# 3단계: 가중 결합
d_offsets = tl.arange(0, HEAD_DIM)
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for n in tl.static_range(N):
weight = tl.exp(lse_val - global_lse)
out_vals = tl.load(recv_output_ptr + ...)
acc += out_vals.to(tl.float32) * weight
tl.store(out_ptr + final_offsets, acc)
Triton 커널은 (B, H_local) 그리드로 실행되며, 각 프로그램이 하나의 (batch, head) 조합을 처리한다. 3-pass 알고리즘으로 max LSE, softmax 가중치, 가중 결합을 순차적으로 계산한다. tl.static_range(N)으로 루프를 언롤링하여 성능을 극대화한다.
All-to-All 통신 파이프라인
def dcp_a2a_lse_reduce(cp_attn_out, cp_attn_lse, cp_group, ...):
B, H, D = local_output.shape
H_per_rank = H // world_size
# [B, H, D] -> [N, B, H/N, D]로 재구성
send_output = local_output.view(B, world_size, H_per_rank, D)\
.permute(1, 0, 2, 3).contiguous()
recv_output = torch.empty_like(send_output)
# 비동기 All-to-All (출력과 LSE 오버랩)
work_output = dist.all_to_all_single(
recv_output.view(-1), send_output.view(-1),
group=cp_group.device_group, async_op=True,
)
work_lse = dist.all_to_all_single(
recv_lse.view(-1), send_lse.view(-1),
group=cp_group.device_group, async_op=True,
)
work_output.wait()
work_lse.wait()
# Triton 커널로 로컬 결합 (통신 불필요)
return dcp_lse_combine_triton(recv_output, recv_lse, ...)
텐서 흐름은 다음과 같다:
- 각 랭크가 전체 헤드의 부분 어텐션 출력을 가짐
[B, H, D] - 헤드를 N개 청크로 분할하여 All-to-All 교환
- 교환 후 각 랭크는 자신의 헤드에 대한 모든 KV 샤드의 결과를 가짐
- Triton 커널로 로컬 LSE 가중 결합
왜 이 설계인가
-
통신 횟수 감소: AG+RS가 레이어당 3회 NCCL 호출을 필요로 하는 반면, A2A는 2회로 줄인다. 긴 컨텍스트 디코드에서 NCCL 레이턴시가 상당한 비중을 차지하므로 이 차이가 크다.
-
비동기 오버랩: 출력과 LSE의 All-to-All을
async_op=True로 동시에 시작한다. 두 통신이 네트워크 대역폭을 공유하면서 파이프라인 효과를 낸다. -
Triton으로 로컬 결합: All-to-All 이후의 LSE 가중 결합은 통신이 필요 없는 순수 로컬 연산이다. Triton 커널로 구현하여 GPU에서 효율적으로 실행한다.
-
사용법 간소화:
vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a한 줄로 활성화할 수 있다. 기존 텐서 병렬 설정과 직교하여 독립적으로 조합 가능하다.
논문 핵심 내용
Helix Parallelism / DCP: Context Parallelism for Long-Context LLM Inference (2507.07120) 논문은 긴 컨텍스트 디코딩에 최적화된 하이브리드 병렬화 전략을 제안했다.
핵심 아이디어: Helix Parallelism은 어텐션 단계에서 KV 병렬화를 적용해 KV 캐시를 GPU 간에 분산하고, FFN 단계에서는 같은 GPU를 텐서 병렬(TP) 또는 TP x Expert Parallel(EP)로 재사용하는 하이브리드 실행 전략이다. 이렇게 하면 단일 병렬화 방식의 한계를 극복할 수 있다.
주요 성능 수치
| 메트릭 | 모델 | 수치 |
|---|---|---|
| TTL(Token-to-Token Latency) 감소 | DeepSeek-R1 | 최대 1.5x |
| TTL 감소 | Llama-405B | 1.13x |
| 배치 크기 확장 | DeepSeek-R1 | 동일 지연 예산 내 32x |
| 처리량 및 배치 확장 | Llama-405B | TP 대비 4x |
NVIDIA Blackwell(GB200) GPU에서 FP4 정밀도로 100만 토큰 컨텍스트를 처리하는 시나리오에서 테스트됐다. DeepSeek-R1에서 동일한 지연 예산으로 32배 더 큰 배치를 처리할 수 있다는 결과는, MoE 모델의 FFN 단계에서 EP를 활용하는 것이 매우 효과적임을 보여준다. Llama-405B 같은 Dense 모델에서도 TP 대비 4배 높은 처리량을 달성했다.
참고
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링
- 현재글 : [vLLM] Context Parallelism: 컨텍스트 병렬화
- 다음글 [vLLM] KV Transfer Connectors: KV 캐시 전송 프레임워크
댓글