본문으로 건너뛰기

[vLLM] Compilation Fusion Passes: 컴파일 퓨전 최적화

들어가며

LLM 추론에서 연산 퓨전은 커널 런치 오버헤드와 메모리 대역폭을 줄이는 핵심 최적화다. vLLM은 vllm/compilation/passes/fusion/ 디렉토리에서 torch.compile의 패턴 매칭 인프라를 활용한 다양한 퓨전 패스를 구현하고 있다. RMSNorm+양자화, AllReduce+RMSNorm, RoPE+KVCache 등 LLM 특화 퓨전을 제공한다.

공식 문서

vLLM 공식 문서: Fusions

핵심 구조/코드 분석

퓨전 패스 목록

vllm/compilation/passes/fusion/ 디렉토리에는 다음 퓨전 패스들이 있다:

  • rms_quant_fusion.py - RMSNorm + FP8 양자화 퓨전
  • allreduce_rms_fusion.py - AllReduce + RMSNorm 퓨전
  • rope_kvcache_fusion.py - RoPE + KV Cache 퓨전
  • act_quant_fusion.py - Activation + 양자화 퓨전
  • attn_quant_fusion.py - Attention + 양자화 퓨전
  • mla_attn_quant_fusion.py - MLA Attention + 양자화 퓨전
  • sequence_parallelism.py - 시퀀스 병렬화 퓨전
  • collective_fusion.py - 집합 통신 퓨전

RMSNormQuantFusionPass 상세

class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        self.patterns = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
        for epsilon in [1e-5, 1e-6]:
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

두 가지 epsilon 값(1e-5, 1e-6)에 대해 모든 패턴 조합을 등록한다. LLaMA 계열은 1e-5, 일부 모델은 1e-6을 사용하기 때문이다.

패턴 매칭 구조

class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(input, weight, scale):
            result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(input, weight, scale):
            result = torch.empty(input.shape, device=input.device, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,  # torch.ops._C.rms_norm_static_fp8_quant
                result=result, input=input, weight=weight,
                scale=scale, epsilon=self.epsilon,
            )
            return at[1]

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass,
                                extra_check=_rms_input_weight_dtype_match)

pattern 함수는 "이런 연산 패턴을 찾아라"를 정의하고, replacement는 "그걸 이렇게 바꿔라"를 정의한다. auto_functionalized로 fused C++ 커널을 호출한다.

Fused Add + RMSNorm + Group Quantization

FUSED_OPS = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False): torch.ops._C.rms_norm_static_fp8_quant,
    FusedRMSQuantKey(kFp8StaticTensorSym, True): torch.ops._C.fused_add_rms_norm_static_fp8_quant,
    FusedRMSQuantKey(kFp8DynamicTokenSym, False): torch.ops._C.rms_norm_dynamic_per_token_quant,
    FusedRMSQuantKey(kFp8Dynamic128Sym, False): torch.ops._C.rms_norm_per_block_quant,
    ...
}

Static/Dynamic per-tensor, per-token, per-block(128/64) 양자화와 fused_add 유무를 조합한 매핑 테이블이다. 각 조합마다 전용 C++ 커널이 있다.

타입 검증

def _rms_input_weight_dtype_match(match: pm.Match) -> bool:
    for node in match.nodes:
        if node.target == _RMS_NORM_OP:
            x, weight = node.args[0], node.args[1]
            if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
                return x.meta["val"].dtype == weight.meta["val"].dtype
    return True

RMSNorm의 입력과 가중치 dtype이 다르면 퓨전을 방지한다. 현재 fused 커널이 mixed dtype을 지원하지 않기 때문이다.

왜 이 설계인가

  1. 패턴 매칭 기반 접근: 수동으로 그래프를 순회하며 패턴을 찾는 대신, torch의 PatternMatcherPass 인프라를 활용한다. 이는 패턴 정의와 변환을 선언적으로 작성할 수 있게 해주며, 패턴 간 충돌도 자동으로 관리한다.

  2. Fused Add 우선 매칭: FusedAddRMSNorm 패턴을 일반 RMSNorm보다 먼저 등록한다. 일반 RMSNorm은 FusedAdd 패턴의 부분집합이므로, 순서가 반대면 더 큰 퓨전 기회를 놓치게 된다.

  3. E8M0 스케일 지원: NVFP4 같은 새로운 양자화 형식을 위해 E8M0 스케일, column-major 스케일, TMA 정렬 등의 옵션을 조합적으로 등록한다. 이는 하드웨어 발전에 따라 새로운 양자화 커널을 쉽게 추가할 수 있도록 하는 확장 가능한 설계다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글