본문으로 건너뛰기

[flashinfer] FlashInfer의 고성능 분산 연산: All-Gather Matmul 최적화 분석

PR 링크: flashinfer-ai/flashinfer#2665 상태: Merged | 변경: +None / -None

들어가며

대규모 언어 모델(LLM)의 추론 및 학습 과정에서 Tensor Parallelism(TP)이나 Sequence Parallelism(SP)을 적용할 때, All-gatherMatmul을 결합한 연산은 필수적입니다. 기존 방식은 통신과 연산을 순차적으로 수행하여 GPU 유휴 시간이 발생하곤 했습니다. 이번 FlashInfer PR에서는 Push-Wait 알고리즘을 도입하여 통신과 연산을 오버랩(Overlap)함으로써 고성능 분산 행렬 곱셈을 구현했습니다.

코드 분석

1. Push-Wait 알고리즘의 핵심 (all_gather_matmul_cutile.py)

이 PR의 핵심은 wait_signal_matmul_kernel입니다. 각 랭크는 자신의 로컬 데이터를 피어 GPU의 메모리에 직접 쓰고(Push), 연산 커널은 데이터가 도착했음을 알리는 신호를 기다리며(Wait) 즉시 연산을 수행합니다.

# Wait for input ready signal
if shift > 0:
    signal_index = (peer, chunk_idx)
    signal = ct.load(signal_pad, index=signal_index, shape=(), padding_mode=zero_pad)
    while signal == 0:
        signal = ct.load(signal_pad, index=signal_index, shape=(), padding_mode=zero_pad)

위 코드처럼 signal_pad를 통해 데이터 가용성을 확인하며, ct.mma를 통해 Tensor Core 연산을 수행합니다. 이는 전통적인 NCCL All-gather 후 GEMM을 수행하는 방식보다 훨씬 낮은 레이턴시를 보장합니다.

2. 동적 백엔드 라우팅 (all_gather_matmul.py)

하드웨어 아키텍처에 따라 최적화된 커널을 선택하도록 설계되었습니다. Blackwell(SM >= 100) 아키텍처에서는 cuTile을, 그 이전 세대에서는 Triton을 사용합니다.

@register_custom_op("flashinfer::all_gather_matmul", mutates_args=[])
def all_gather_matmul(...):
    major, _ = torch.cuda.get_device_capability(inp.device)
    if major >= 10:
        return all_gather_matmul_cutile(inp, w, group, verbose=verbose)
    return all_gather_matmul_triton(inp, w, group, verbose=verbose)

왜 이게 좋은가

성능 수치

4 x H100 환경에서 벤치마크 결과, 토큰 수가 많을수록(16384 기준) 기존 방식 대비 1.32배의 속도 향상을 보였습니다. 이는 통신과 연산의 병렬화가 실제 워크로드에서 얼마나 큰 이득을 주는지 잘 보여줍니다.

일반적 교훈

  1. Overlap의 중요성: 분산 시스템에서 통신은 병목의 주범입니다. 연산과 통신을 비동기적으로 겹치는(Overlap) 설계는 필수적입니다.
  2. Hardware-Specific Optimization: 최신 GPU(Blackwell)의 기능을 활용하기 위해 cuTile과 같은 저수준 API를 활용하는 전략은 성능 최적화의 핵심입니다.
  3. 코드 구조: register_custom_op를 통해 PyTorch 생태계와 자연스럽게 통합하면서도, 내부적으로는 고도로 최적화된 커스텀 커널을 호출하는 방식은 확장성과 성능을 모두 잡는 좋은 패턴입니다.

리뷰 피드백 반영

리뷰어들은 코드의 가독성과 API 일관성을 위해 @flashinfer_api 데코레이터 추가를 제안했으며, 이는 최종 코드에 반영되어 라이브러리 사용자들에게 더 나은 개발자 경험(DX)을 제공하게 되었습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글