[flashinfer] FlashInfer의 고성능 분산 연산: All-Gather Matmul 최적화 분석
PR 링크: flashinfer-ai/flashinfer#2665 상태: Merged | 변경: +None / -None
들어가며
대규모 언어 모델(LLM)의 추론 및 학습 과정에서 Tensor Parallelism(TP)이나 Sequence Parallelism(SP)을 적용할 때, All-gather와 Matmul을 결합한 연산은 필수적입니다. 기존 방식은 통신과 연산을 순차적으로 수행하여 GPU 유휴 시간이 발생하곤 했습니다. 이번 FlashInfer PR에서는 Push-Wait 알고리즘을 도입하여 통신과 연산을 오버랩(Overlap)함으로써 고성능 분산 행렬 곱셈을 구현했습니다.
코드 분석
1. Push-Wait 알고리즘의 핵심 (all_gather_matmul_cutile.py)
이 PR의 핵심은 wait_signal_matmul_kernel입니다. 각 랭크는 자신의 로컬 데이터를 피어 GPU의 메모리에 직접 쓰고(Push), 연산 커널은 데이터가 도착했음을 알리는 신호를 기다리며(Wait) 즉시 연산을 수행합니다.
# Wait for input ready signal
if shift > 0:
signal_index = (peer, chunk_idx)
signal = ct.load(signal_pad, index=signal_index, shape=(), padding_mode=zero_pad)
while signal == 0:
signal = ct.load(signal_pad, index=signal_index, shape=(), padding_mode=zero_pad)
위 코드처럼 signal_pad를 통해 데이터 가용성을 확인하며, ct.mma를 통해 Tensor Core 연산을 수행합니다. 이는 전통적인 NCCL All-gather 후 GEMM을 수행하는 방식보다 훨씬 낮은 레이턴시를 보장합니다.
2. 동적 백엔드 라우팅 (all_gather_matmul.py)
하드웨어 아키텍처에 따라 최적화된 커널을 선택하도록 설계되었습니다. Blackwell(SM >= 100) 아키텍처에서는 cuTile을, 그 이전 세대에서는 Triton을 사용합니다.
@register_custom_op("flashinfer::all_gather_matmul", mutates_args=[])
def all_gather_matmul(...):
major, _ = torch.cuda.get_device_capability(inp.device)
if major >= 10:
return all_gather_matmul_cutile(inp, w, group, verbose=verbose)
return all_gather_matmul_triton(inp, w, group, verbose=verbose)
왜 이게 좋은가
성능 수치
4 x H100 환경에서 벤치마크 결과, 토큰 수가 많을수록(16384 기준) 기존 방식 대비 1.32배의 속도 향상을 보였습니다. 이는 통신과 연산의 병렬화가 실제 워크로드에서 얼마나 큰 이득을 주는지 잘 보여줍니다.
일반적 교훈
- Overlap의 중요성: 분산 시스템에서 통신은 병목의 주범입니다. 연산과 통신을 비동기적으로 겹치는(Overlap) 설계는 필수적입니다.
- Hardware-Specific Optimization: 최신 GPU(Blackwell)의 기능을 활용하기 위해 cuTile과 같은 저수준 API를 활용하는 전략은 성능 최적화의 핵심입니다.
- 코드 구조:
register_custom_op를 통해 PyTorch 생태계와 자연스럽게 통합하면서도, 내부적으로는 고도로 최적화된 커스텀 커널을 호출하는 방식은 확장성과 성능을 모두 잡는 좋은 패턴입니다.
리뷰 피드백 반영
리뷰어들은 코드의 가독성과 API 일관성을 위해 @flashinfer_api 데코레이터 추가를 제안했으며, 이는 최종 코드에 반영되어 라이브러리 사용자들에게 더 나은 개발자 경험(DX)을 제공하게 되었습니다.
참고 자료
- https://pytorch.org/docs/stable/distributed.html
- https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang DeepSeekV3 Router GEMM 최적화: FlashInfer 커널 도입 및 벤치마킹
- [sglang] SGLang의 SM120 FP8 Blockwise GEMM 성능 최적화: Pingpong 스케줄 도입
- [flashinfer] FlashInfer 오토튜너 최적화: 하이브리드 토큰 버킷 도입
- [flashinfer] FlashInfer, CuTe DSL 기반 FMHA 커널 통합으로 사전 생성(Prefill) 성능 극대화
- [cpython] Python JIT Shim 빌드 프로세스 개선: 런타임 컴파일에서 빌드 타임 링크로
PR Analysis 의 다른글
- 이전글 [ACE-Step-1.5] ACE-Step에 파동대역 보정(DCW) 샘플러 훅 추가: SNR-t 편향 개선
- 현재글 : [flashinfer] FlashInfer의 고성능 분산 연산: All-Gather Matmul 최적화 분석
- 다음글 [onnxruntime] ONNX Runtime 스레드 풀의 지능형 대기: Exponential Backoff 도입으로 성능 및 전력 효율성 향상
댓글