본문으로 건너뛰기

[vLLM] Context Parallelism: 컨텍스트 병렬화

들어가며

긴 컨텍스트(수만~수십만 토큰)를 처리할 때 단일 GPU의 KV 캐시 메모리로는 부족하다. Context Parallelism은 KV 캐시를 여러 GPU에 분산하여 이 문제를 해결한다. vLLM의 Decode Context Parallelism(DCP)은 디코드 단계에서 All-to-All(A2A) 통신을 사용하여 기존 AllGather + ReduceScatter 방식보다 통신 횟수를 줄인다.

소스 경로: vllm/v1/attention/ops/dcp_alltoall.py

논문: Context Parallelism for Long-Context LLM Inference

공식 문서

vLLM 공식 문서: Context Parallel Deployment

핵심 구조/코드 분석

핵심 아이디어

기존 방식(AG+RS)은 어텐션 레이어당 3번의 NCCL 호출이 필요하다:

  1. Q 텐서 AllGather
  2. K 메타데이터 AllGather
  3. 출력 ReduceScatter

A2A 방식은 이를 2번으로 줄인다:

  1. 부분 어텐션 출력 All-to-All
  2. LSE(Log-Sum-Exp) 값 All-to-All

LSE 가중 결합 (CPU 참조 구현)

def _lse_weighted_combine(
    outputs: torch.Tensor,  # [N, B, H, D]
    lses: torch.Tensor,     # [N, B, H]
    return_lse: bool = False,
    is_lse_base_on_e: bool = True,
) -> torch.Tensor:
    N, B, H, D = outputs.shape

    # NaN/Inf 처리
    lses = torch.where(
        torch.isnan(lses) | torch.isinf(lses),
        torch.tensor(float("-inf"), device=lses.device, dtype=lses.dtype),
        lses,
    )

    # 수치 안정성을 위한 max LSE 계산
    lse_max, _ = lses.max(dim=0)

    # Softmax 가중치 계산
    weights = torch.exp(lses - lse_max.unsqueeze(0))
    weight_sum = weights.sum(dim=0, keepdim=True)
    weights = weights / weight_sum.clamp(min=1e-10)

    # 가중 결합
    result = (outputs * weights.unsqueeze(-1)).sum(dim=0)
    return result

각 GPU가 자신의 KV 샤드에 대해 계산한 부분 어텐션 출력과 LSE를 교환한 후, softmax 가중 결합으로 최종 출력을 생성한다. 수치 안정성을 위해 max LSE를 빼서 오버플로를 방지한다.

Triton 커널

@triton.jit
def _dcp_lse_combine_kernel(
    recv_output_ptr, recv_lse_ptr, out_ptr, out_lse_ptr,
    # strides...
    N: tl.constexpr, HEAD_DIM: tl.constexpr,
    IS_BASE_E: tl.constexpr, RETURN_LSE: tl.constexpr,
):
    batch_idx = tl.program_id(0).to(tl.int64)
    head_idx = tl.program_id(1).to(tl.int64)

    # 1단계: max LSE 계산 (수치 안정성)
    lse_max = -float("inf")
    for n in tl.static_range(N):
        lse_val = tl.load(recv_lse_ptr + n * rl_stride_N + base_lse_offset)
        lse_max = tl.maximum(lse_max, lse_val)

    # 2단계: exp(lse - max) 합산
    lse_sum = 0.0
    for n in tl.static_range(N):
        lse_val = tl.load(recv_lse_ptr + ...)
        if IS_BASE_E:
            lse_sum += tl.exp(lse_val - lse_max)
        else:
            lse_sum += tl.exp2(lse_val - lse_max)

    # 3단계: 가중 결합
    d_offsets = tl.arange(0, HEAD_DIM)
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
    for n in tl.static_range(N):
        weight = tl.exp(lse_val - global_lse)
        out_vals = tl.load(recv_output_ptr + ...)
        acc += out_vals.to(tl.float32) * weight

    tl.store(out_ptr + final_offsets, acc)

Triton 커널은 (B, H_local) 그리드로 실행되며, 각 프로그램이 하나의 (batch, head) 조합을 처리한다. 3-pass 알고리즘으로 max LSE, softmax 가중치, 가중 결합을 순차적으로 계산한다. tl.static_range(N)으로 루프를 언롤링하여 성능을 극대화한다.

All-to-All 통신 파이프라인

def dcp_a2a_lse_reduce(cp_attn_out, cp_attn_lse, cp_group, ...):
    B, H, D = local_output.shape
    H_per_rank = H // world_size

    # [B, H, D] -> [N, B, H/N, D]로 재구성
    send_output = local_output.view(B, world_size, H_per_rank, D)\
                              .permute(1, 0, 2, 3).contiguous()
    recv_output = torch.empty_like(send_output)

    # 비동기 All-to-All (출력과 LSE 오버랩)
    work_output = dist.all_to_all_single(
        recv_output.view(-1), send_output.view(-1),
        group=cp_group.device_group, async_op=True,
    )
    work_lse = dist.all_to_all_single(
        recv_lse.view(-1), send_lse.view(-1),
        group=cp_group.device_group, async_op=True,
    )
    work_output.wait()
    work_lse.wait()

    # Triton 커널로 로컬 결합 (통신 불필요)
    return dcp_lse_combine_triton(recv_output, recv_lse, ...)

텐서 흐름은 다음과 같다:

  1. 각 랭크가 전체 헤드의 부분 어텐션 출력을 가짐 [B, H, D]
  2. 헤드를 N개 청크로 분할하여 All-to-All 교환
  3. 교환 후 각 랭크는 자신의 헤드에 대한 모든 KV 샤드의 결과를 가짐
  4. Triton 커널로 로컬 LSE 가중 결합

왜 이 설계인가

  1. 통신 횟수 감소: AG+RS가 레이어당 3회 NCCL 호출을 필요로 하는 반면, A2A는 2회로 줄인다. 긴 컨텍스트 디코드에서 NCCL 레이턴시가 상당한 비중을 차지하므로 이 차이가 크다.

  2. 비동기 오버랩: 출력과 LSE의 All-to-All을 async_op=True로 동시에 시작한다. 두 통신이 네트워크 대역폭을 공유하면서 파이프라인 효과를 낸다.

  3. Triton으로 로컬 결합: All-to-All 이후의 LSE 가중 결합은 통신이 필요 없는 순수 로컬 연산이다. Triton 커널로 구현하여 GPU에서 효율적으로 실행한다.

  4. 사용법 간소화: vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a 한 줄로 활성화할 수 있다. 기존 텐서 병렬 설정과 직교하여 독립적으로 조합 가능하다.

논문 핵심 내용

Helix Parallelism / DCP: Context Parallelism for Long-Context LLM Inference (2507.07120) 논문은 긴 컨텍스트 디코딩에 최적화된 하이브리드 병렬화 전략을 제안했다.

핵심 아이디어: Helix Parallelism은 어텐션 단계에서 KV 병렬화를 적용해 KV 캐시를 GPU 간에 분산하고, FFN 단계에서는 같은 GPU를 텐서 병렬(TP) 또는 TP x Expert Parallel(EP)로 재사용하는 하이브리드 실행 전략이다. 이렇게 하면 단일 병렬화 방식의 한계를 극복할 수 있다.

주요 성능 수치

메트릭 모델 수치
TTL(Token-to-Token Latency) 감소 DeepSeek-R1 최대 1.5x
TTL 감소 Llama-405B 1.13x
배치 크기 확장 DeepSeek-R1 동일 지연 예산 내 32x
처리량 및 배치 확장 Llama-405B TP 대비 4x

NVIDIA Blackwell(GB200) GPU에서 FP4 정밀도로 100만 토큰 컨텍스트를 처리하는 시나리오에서 테스트됐다. DeepSeek-R1에서 동일한 지연 예산으로 32배 더 큰 배치를 처리할 수 있다는 결과는, MoE 모델의 FFN 단계에서 EP를 활용하는 것이 매우 효과적임을 보여준다. Llama-405B 같은 Dense 모델에서도 TP 대비 4배 높은 처리량을 달성했다.

참고

댓글

관련 포스트

vLLM 의 다른글