[vLLM] Compilation Fusion Passes: 컴파일 퓨전 최적화
들어가며
LLM 추론에서 연산 퓨전은 커널 런치 오버헤드와 메모리 대역폭을 줄이는 핵심 최적화다. vLLM은 vllm/compilation/passes/fusion/ 디렉토리에서 torch.compile의 패턴 매칭 인프라를 활용한 다양한 퓨전 패스를 구현하고 있다. RMSNorm+양자화, AllReduce+RMSNorm, RoPE+KVCache 등 LLM 특화 퓨전을 제공한다.
공식 문서
vLLM 공식 문서: Fusions
핵심 구조/코드 분석
퓨전 패스 목록
vllm/compilation/passes/fusion/ 디렉토리에는 다음 퓨전 패스들이 있다:
rms_quant_fusion.py- RMSNorm + FP8 양자화 퓨전allreduce_rms_fusion.py- AllReduce + RMSNorm 퓨전rope_kvcache_fusion.py- RoPE + KV Cache 퓨전act_quant_fusion.py- Activation + 양자화 퓨전attn_quant_fusion.py- Attention + 양자화 퓨전mla_attn_quant_fusion.py- MLA Attention + 양자화 퓨전sequence_parallelism.py- 시퀀스 병렬화 퓨전collective_fusion.py- 집합 통신 퓨전
RMSNormQuantFusionPass 상세
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
self.patterns = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
for epsilon in [1e-5, 1e-6]:
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
두 가지 epsilon 값(1e-5, 1e-6)에 대해 모든 패턴 조합을 등록한다. LLaMA 계열은 1e-5, 일부 모델은 1e-6을 사용하기 때문이다.
패턴 매칭 구조
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(input, weight, scale):
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
return self.quant_matcher(result_rms, scale)[0]
def replacement(input, weight, scale):
result = torch.empty(input.shape, device=input.device, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP, # torch.ops._C.rms_norm_static_fp8_quant
result=result, input=input, weight=weight,
scale=scale, epsilon=self.epsilon,
)
return at[1]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass,
extra_check=_rms_input_weight_dtype_match)
pattern 함수는 "이런 연산 패턴을 찾아라"를 정의하고, replacement는 "그걸 이렇게 바꿔라"를 정의한다. auto_functionalized로 fused C++ 커널을 호출한다.
Fused Add + RMSNorm + Group Quantization
FUSED_OPS = {
FusedRMSQuantKey(kFp8StaticTensorSym, False): torch.ops._C.rms_norm_static_fp8_quant,
FusedRMSQuantKey(kFp8StaticTensorSym, True): torch.ops._C.fused_add_rms_norm_static_fp8_quant,
FusedRMSQuantKey(kFp8DynamicTokenSym, False): torch.ops._C.rms_norm_dynamic_per_token_quant,
FusedRMSQuantKey(kFp8Dynamic128Sym, False): torch.ops._C.rms_norm_per_block_quant,
...
}
Static/Dynamic per-tensor, per-token, per-block(128/64) 양자화와 fused_add 유무를 조합한 매핑 테이블이다. 각 조합마다 전용 C++ 커널이 있다.
타입 검증
def _rms_input_weight_dtype_match(match: pm.Match) -> bool:
for node in match.nodes:
if node.target == _RMS_NORM_OP:
x, weight = node.args[0], node.args[1]
if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
return x.meta["val"].dtype == weight.meta["val"].dtype
return True
RMSNorm의 입력과 가중치 dtype이 다르면 퓨전을 방지한다. 현재 fused 커널이 mixed dtype을 지원하지 않기 때문이다.
왜 이 설계인가
-
패턴 매칭 기반 접근: 수동으로 그래프를 순회하며 패턴을 찾는 대신, torch의
PatternMatcherPass인프라를 활용한다. 이는 패턴 정의와 변환을 선언적으로 작성할 수 있게 해주며, 패턴 간 충돌도 자동으로 관리한다. -
Fused Add 우선 매칭:
FusedAddRMSNorm패턴을 일반 RMSNorm보다 먼저 등록한다. 일반 RMSNorm은 FusedAdd 패턴의 부분집합이므로, 순서가 반대면 더 큰 퓨전 기회를 놓치게 된다. -
E8M0 스케일 지원: NVFP4 같은 새로운 양자화 형식을 위해 E8M0 스케일, column-major 스케일, TMA 정렬 등의 옵션을 조합적으로 등록한다. 이는 하드웨어 발전에 따라 새로운 양자화 커널을 쉽게 추가할 수 있도록 하는 확장 가능한 설계다.
참고 자료
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] Model Weight Offloading: 가중치 CPU 오프로딩
- 현재글 : [vLLM] Compilation Fusion Passes: 컴파일 퓨전 최적화
- 다음글 [vLLM] Reasoning & Tool Calling: 추론 파서와 도구 호출 파서
댓글