[vllm] AMD ROCm을 위한 Triton 기반 W4A16 커널 도입: MI300X 성능 최적화 분석
PR 링크: vllm-project/vllm#37352 상태: Merged | 변경: +None / -None
들어가며
LLM 추론 엔진인 vLLM 프로젝트에 AMD ROCm 플랫폼을 위한 새로운 Triton 기반 W4A16(Weight 4-bit, Activation 16-bit) 선형 커널이 도입되었습니다. 기존에는 AMD 환경에서 Conch나 ExLlama 같은 커널을 사용해왔으나, 이번 PR을 통해 Triton 기반의 고성능 커널이 추가되면서 특히 Batch Size가 큰 상황에서의 처리량(Throughput)이 획기적으로 개선되었습니다.
이 PR은 AMD MI300X(gfx942) 하드웨어에서 INT4 가중치와 FP16 활성화 함수를 사용하는 모델의 추론 속도를 높이는 것을 목표로 하며, 대칭(Symmetric) 및 비대칭(Asymmetric) 양자화를 모두 지원합니다.
코드 분석: 핵심 변경 사항
1. Triton 기반 GEMM 커널 구현 및 검증
가장 핵심적인 변화는 triton_w4a16_gemm 함수의 도입입니다. 이 함수는 가중치를 4비트로 압축하여 메모리 대역폭을 절약하면서도, 연산은 16비트 정밀도를 유지하여 정확도 손실을 최소화합니다.
테스트 코드에서는 다음과 같이 레퍼런스 구현과 Triton 커널의 결과를 비교하여 수치적 정확성을 검증합니다.
# Before: 레퍼런스 구현 (PyTorch 기반)
def _w4a16_reference(a_mk, b_packed_kn8, scales_gn, ...):
# ... 가중치 언패킹 및 FP32 변환 후 행렬 곱셈 수행 ...
w_fp = (w_int4 - z_full).to(torch.float32) * s_full
out = a_mk.to(torch.float32) @ w_fp
return out.to(a_mk.dtype)
# After: Triton 커널 호출
out = triton_w4a16_gemm(
a=a,
b_q=b_packed,
scales=scales,
qzeros=qzeros,
group_size=G,
zp_bias=8,
)
2. 가중치 레이아웃 재포장 (Weight Repacking)
Compressed-Tensors(CT) 체크포인트의 가중치 레이아웃은 Triton 커널이 기대하는 레이아웃과 다를 수 있습니다. 이를 위해 로딩 후 가중치를 커널 최적화에 적합한 형태로 변환하는 로직이 추가되었습니다.
# After: 가중치 로딩 후 레이아웃 변환 로직
def test_triton_w4a16_process_weights_after_loading_repacks_layout():
# ... (생략) ...
# CT 체크포인트 레이아웃 [N, K//8] -> Triton 커널용 [K, N//8] 변환 검증
kernel.process_weights_after_loading(layer)
assert tuple(layer.weight_packed.shape) == (K, N // 8)
assert tuple(layer.weight_scale.shape) == (K // G, N)
3. 하드웨어별 최적화 (MI300 vs RDNA 3.5)
리뷰 과정에서 AMD 엔지니어(@mgehre-amd)의 제안으로 RDNA 3.5(Strix Halo) 아키텍처를 위한 튜닝이 포함되었습니다. ROCm 플랫폼에서는 device_capability 대신 GPU 아키텍처 이름을 직접 확인하는 방식을 권장하며, 이에 따라 아키텍처별로 최적의 블록 크기를 설정합니다.
왜 이 최적화가 좋은가?
1. 압도적인 Throughput 향상
벤치마크 결과에 따르면, Granite 3.1-8B 모델 기준 Batch Size 32에서 기존 Conch 커널 대비 약 16~22%, ExLlama 대비 최대 122%의 성능 향상을 보였습니다.
| Batch | TritonW4A16 | Conch | ExLlama | Triton vs ExLlama |
|---|---|---|---|---|
| 1 | 37.1 tok/s | 31.7 tok/s | 40.3 tok/s | -8% |
| 32 | 1085.1 tok/s | 932.2 tok/s | 489.0 tok/s | +122% |
2. Triton의 확장성과 유지보수성
ExLlama는 C++/HIP으로 작성되어 커널 런칭 오버헤드가 적어 Batch 1(Latency 위주)에서는 유리하지만, Triton은 파이썬 기반으로 작성되어 유지보수가 쉽고 다양한 하드웨어 아키텍처에 맞춰 컴파일 타임 최적화가 가능합니다. 이번 PR에서도 RDNA 3.5와 MI300X를 위한 최적화 코드가 하나의 Triton 커널 내에서 깔끔하게 공존할 수 있었습니다.
3. 엄격한 데이터 정렬 및 연속성 보장
리뷰어 @tjtanaa의 피드백에 따라, 입력 텐서의 연속성(contiguous)을 강제하고 N % 8 == 0과 같은 제약 조건을 명시적으로 체크하도록 수정되었습니다. 이는 런타임에 발생할 수 있는 정의되지 않은 동작(Undefined Behavior)을 방지하는 중요한 안전장치입니다.
결론
이번 TritonW4A16LinearKernel의 도입은 AMD ROCm 사용자들에게 큰 선물과 같습니다. 특히 대규모 서빙 환경(High Batch)에서 Triton의 강력한 최적화 능력을 다시 한번 입증했습니다. vLLM은 이제 NVIDIA뿐만 아니라 AMD 하드웨어에서도 최상위권의 성능을 낼 수 있는 인프라를 더욱 공고히 다지게 되었습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html
- https://triton-lang.org/main/index.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [cpython] CPython JIT 최적화: MAKE_FUNCTION의 불필요한 참조 카운팅 제거
- 현재글 : [vllm] AMD ROCm을 위한 Triton 기반 W4A16 커널 도입: MI300X 성능 최적화 분석
- 다음글 [vllm] vLLM Nemotron Nano VL: Pixel Shuffle 최적화를 통한 성능 향상 분석
댓글