본문으로 건너뛰기

[vllm] [vLLM] W4A16 양자화 모델의 호환성 문제 해결: Triton 커널을 활용한 CUDA Fallback 구현

PR 링크: vllm-project/vllm#43731 상태: Merged | 변경: +3 / -2

들어가며: W4A16 양자화와 정렬(Alignment)의 딜레마

LLM 추론 최적화에서 W4A16(4-bit Weight, 16-bit Activation) 양자화는 메모리 대역폭을 절약하면서도 연산 정밀도를 유지할 수 있는 매우 효율적인 기법입니다. vLLM 프로젝트에서는 이를 위해 Marlin, Machete, Exllama 등 다양한 고성능 커널을 지원하고 있습니다.

하지만 고성능 커널들은 대개 하드웨어 가속을 극대화하기 위해 엄격한 정렬(Alignment) 제약을 가집니다. 예를 들어, Marlin 커널은 K 차원(입구 차원)이 128의 배수여야 한다는 제약이 있습니다. 문제는 최근 유행하는 MoE(Mixture of Experts) 모델이나 특정 아키텍처의 경우 intermediate_size가 2112나 704와 같이 128로 나누어떨어지지 않는 경우가 빈번하다는 점입니다.

이러한 경우, 기존 vLLM은 Ampere(SM80) 아키텍처에서 적절한 커널을 찾지 못해 ValueError를 발생시키며 실행이 중단되었습니다. 이번 포스트에서는 TritonW4A16LinearKernel을 CUDA 환경의 Fallback 커널로 도입하여 이 문제를 어떻게 우아하게 해결했는지 살펴보겠습니다.

문제의 핵심: 왜 특정 모델은 Ampere에서 실행되지 않았나?

PR 설명에 명시된 것처럼, Ampere GPU에서 W4A16 모델을 실행할 때 기존 커널들은 다음과 같은 한계를 가지고 있었습니다:

  1. Marlin: input_size_per_partition % 128 == 0 제약 (2112, 704 등에서 실패).
  2. CutlassW4A8 / Machete: SM90(Hopper) 이상의 최신 아키텍처 요구.
  3. Exllama: float16만 지원하며 bfloat16 활성화 함수 미지원.
  4. Conch / AllSpark: 특정 group_size 제약.

결과적으로, 128 정렬이 되지 않은 bfloat16 기반 W4A16 모델은 Ampere GPU에서 "실행 불가능"한 상태였습니다. 하지만 vLLM에는 이미 ROCm(AMD GPU)용으로 작성된 TritonW4A16LinearKernel이 존재했습니다. 이 커널은 Triton으로 작성되어 플랫폼 독립적이며, N % 8 == 0이라는 훨씬 완화된 제약 조건만을 가집니다.

코드 분석: Fallback 메커니즘 구현

이번 PR의 핵심은 기존에 ROCm 전용으로 묶여 있던 Triton 커널을 CUDA 환경에서도 사용할 수 있도록 개방하고, 이를 최하위 우선순위의 Fallback으로 배치한 것입니다.

1. CUDA 커널 리스트에 Triton 커널 추가

vllm/model_executor/kernels/linear/__init__.py 파일에서 CUDA 환경이 지원하는 커널 목록 끝에 TritonW4A16LinearKernel을 추가했습니다.

Before:

# vllm/model_executor/kernels/linear/__init__.py
PlatformEnum.CUDA: [
    MacheteLinearKernel,
    MarlinLinearKernel,
    ConchLinearKernel,
    ExllamaLinearKernel,
],

After:

# vllm/model_executor/kernels/linear/__init__.py
PlatformEnum.CUDA: [
    MacheteLinearKernel,
    MarlinLinearKernel,
    ConchLinearKernel,
    ExllamaLinearKernel,
    TritonW4A16LinearKernel, # 최하위 우선순위로 추가
],

여기서 중요한 점은 순서입니다. vLLM은 리스트의 앞쪽에 있는 커널부터 can_implement()를 호출하여 적합성을 판단합니다. Triton 커널을 맨 뒤에 배치함으로써, Marlin이나 Machete처럼 더 최적화된 커널이 처리할 수 있는 레이어는 기존 방식대로 고성능 커널이 처리하고, 오직 다른 모든 커널이 거부한 경우에만 Triton 커널이 선택되도록 설계되었습니다.

2. 플랫폼 게이트(Platform Gate) 해제

기존 커널 구현체에서는 ROCm 환경이 아니면 무조건 거절하도록 되어 있었습니다. 이를 CUDA까지 확장했습니다.

Before:

# vllm/model_executor/kernels/linear/mixed_precision/triton_w4a16.py
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
    if not current_platform.is_rocm():
        return False, "TritonW4A16LinearKernel only targets ROCm"
    # ... 생략

After:

# vllm/model_executor/kernels/linear/mixed_precision/triton_w4a16.py
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
    if not (current_platform.is_rocm() or current_platform.is_cuda()):
        return False, "TritonW4A16LinearKernel requires CUDA or ROCm"
    # ... 생략

Triton 언어 자체가 LLVM을 거쳐 PTX(NVIDIA)와 HSACO(AMD)를 모두 생성할 수 있기 때문에, 플랫폼 체크 로직만 수정하면 즉시 두 환경 모두에서 동작할 수 있습니다.

왜 이게 좋은 개선인가?

1. 견고성(Robustness) 확보

이 변경 전에는 특정 모델 형상에서 ValueError가 발생하며 서버가 죽는 치명적인 문제가 있었습니다. 최적화보다 중요한 것은 "일단 돌아가게 만드는 것"입니다. Triton 커널을 Fallback으로 활용함으로써 vLLM은 더 넓은 범위의 모델 형상을 지원할 수 있게 되었습니다.

2. 성능 저하 없는 기능 확장 (Zero Regression)

우선순위 기반의 커널 선택 로직 덕분에, 기존에 Marlin으로 잘 돌아가던 모델들은 성능 저하가 전혀 없습니다. 오직 "실행 불가능"했던 모델들만 Triton 커널을 타고 실행되므로, 사용자 입장에서는 손해 볼 것이 없는 업데이트입니다.

3. 코드 재사용성 극대화

플랫폼별로 별도의 커널을 작성하는 대신, 이미 검증된 Triton 커널의 플랫폼 제약을 풀어줌으로써 유지보수 비용을 최소화했습니다. 이는 "Write Once, Run Anywhere"라는 Triton의 철학을 잘 활용한 사례입니다.

테스트 결과 및 결론

테스트 결과, intermediate_size=2112인 MoE 모델이 A100(SM80)에서 성공적으로 로드되고 추론되는 것이 확인되었습니다. 이전에는 ValueError로 인해 시작조차 불가능했던 작업입니다.

# Before
ValueError: Failed to find a kernel that can implement the WNA16 linear layer

# After
Processed prompts: 100%|█████| 1/1 [00:05<00:00, 5.81s/it]
The capital of Brazil is Brasília.

이번 PR은 복잡한 알고리즘의 변경 없이도 시스템의 설계 구조(Priority-based selection)와 도구(Triton)의 특성을 잘 이해했을 때 얼마나 효과적인 개선이 가능한지를 보여줍니다. 시니어 엔지니어로서 우리는 항상 "가장 빠른 길"을 찾되, 그 길이 막혔을 때 안전하게 우회할 수 있는 "Fallback"을 설계하는 능력을 갖추어야 합니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글