[vllm] vLLM W8W8 그룹 양자화 성능 최적화: 2D-Grid를 통한 Divmod 제거
PR 링크: vllm-project/vllm#42153 상태: Merged | 변경: +0 / -0
들어가며
vLLM은 대규모 언어 모델(LLM) 서빙을 위한 고성능 추론 엔진으로, GPU 활용률을 극대화하여 처리량을 높이는 데 중점을 둡니다. 이 PR은 vLLM의 핵심 연산 중 하나인 W8W8(Weight 8-bit, Activation 8-bit) 그룹 양자화 커널의 성능을 최적화합니다. 특히, CUDA 커널 내에서 그룹 ID를 행렬 좌표로 변환하는 과정에서 발생하는 divmod (나눗셈과 나머지 연산)의 오버헤드를 제거하여 전반적인 처리 속도를 향상시키는 것이 목표입니다.
divmod 연산은 GPU에서 상대적으로 비용이 많이 드는 연산으로 알려져 있습니다. 특히, 컴파일 타임에 상수로 결정될 수 있는 값들에 대해 런타임에 divmod를 수행하는 것은 불필요한 성능 저하를 야기할 수 있습니다. 이 PR은 이러한 divmod 연산을 2D-grid와 템플릿 상수를 활용한 인덱스 언팩킹으로 대체하여, 컴파일러가 이를 간단한 비트 연산으로 최적화할 수 있도록 합니다. 이는 GPU 커널의 실행 효율을 높여 LLM 추론 성능을 개선하는 데 기여합니다.
코드 분석: 2D-Grid로의 전환
이번 PR의 핵심은 per_token_group_quant_8bit_packed_register_kernel CUDA 커널 내에서 global_group_id를 sf_k_idx와 mn_idx로 변환하는 로직을 변경하는 것입니다. 기존에는 1D global_group_id를 padded_groups_per_row로 나누고 나머지 연산을 수행하여 2D 좌표를 계산했습니다. 이제는 2D-grid를 사용하여 블록 인덱스(blockIdx.x, blockIdx.y)와 스레드 인덱스(threadIdx.x)를 조합하여 직접 2D 좌표를 계산합니다.
GetGroupsPerBlockX 함수 추가
가장 먼저, 새로운 헬퍼 함수 GetGroupsPerBlockX가 추가되었습니다. 이 함수는 padded_groups_per_row의 가장 큰 약수 중 16 이하인 값을 반환합니다. 이는 2D-grid의 X축 방향 블록당 그룹 수를 결정하는 데 사용됩니다. kx 값은 16, 8, 4 중 하나가 되며, 이는 ry = 16 / kx와 함께 블록 내 스레드 배치를 최적화하는 데 기여합니다.
Before:
--- a/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
+++ b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -156,6 +156,17 @@ inline int GetGroupsPerBlock(int64_t num_groups) {
return 1;
}
+// Largest divisor of padded_groups_per_row that is <= 16. ry = 16 / kx.
+inline int GetGroupsPerBlockX(int64_t padded_groups_per_row) {
+ if (padded_groups_per_row % 16 == 0) {
+ return 16;
+ }
+ if (padded_groups_per_row % 8 == 0) {
+ return 8;
+ }
+ return 4;
+}
+
void per_token_group_quant_8bit(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
커널 시그니처 및 인덱스 계산 변경
per_token_group_quant_8bit_packed_register_kernel 커널의 템플릿 인수에 kGroupsPerBlockX와 kRowsPerBlock이 추가되었습니다. 이들은 컴파일 타임 상수로, 2D-grid의 블록 크기를 정의합니다. 또한, 커널 인자에서 num_groups_padded와 groups_per_block이 제거되고, padded_groups_per_row만 남게 됩니다.
커널 내부에서는 global_group_id를 직접 계산하는 대신, blockIdx.x, blockIdx.y, local_group_id를 사용하여 sf_k_idx와 mn_idx를 직접 계산합니다. 이 과정에서 divmod 연산이 사라지고, 템플릿 상수를 활용한 간단한 산술 연산으로 대체됩니다.
Before:
--- a/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
+++ b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -247,11 +258,11 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input,
// Constraints: GROUP_SIZE % (THREADS_PER_GROUP * VEC_SIZE) == 0; for
// THREADS_PER_GROUP=8 and bf16/fp16 (VEC_SIZE=16), this means GROUP_SIZE=128.
-template <typename T, typename DST_DTYPE, int GROUP_SIZE>
+template <typename T, typename DST_DTYPE, int GROUP_SIZE, int kGroupsPerBlockX,
+ int kRowsPerBlock>
__global__ void per_token_group_quant_8bit_packed_register_kernel(
const T* __restrict__ input, void* __restrict__ output_q,
- unsigned int* __restrict__ output_s_packed, const int64_t num_groups_padded,
- const int groups_per_block, const int padded_groups_per_row,
+ unsigned int* __restrict__ output_s_packed, const int padded_groups_per_row,
const int groups_per_row, const int mn, const int output_q_mn_extent,
const int tma_aligned_mn, const int64_t num_scale_elems, const float eps,
const float min_8bit, const float max_8bit) {
@@ -260,27 +271,25 @@ __global__ void per_token_group_quant_8bit_packed_register_kernel(
constexpr int VEC_SIZE = 32 / sizeof(T); // 16 for bf16/fp16
static_assert(GROUP_SIZE == THREADS_PER_GROUP * VEC_SIZE,
"GROUP_SIZE must equal THREADS_PER_GROUP * VEC_SIZE");
- // Each group's 8 threads must live in a single warp octet so the
- // 0xffu << (threadIdx.x & 24u) shuffle mask selects exactly the lanes
- // that share a group. Requires 32 % THREADS_PER_GROUP == 0 and the host
- // to launch num_threads as a multiple of THREADS_PER_GROUP (which it does
- // via num_threads = groups_per_block * THREADS_PER_GROUP).
static_assert(32 % THREADS_PER_GROUP == 0,
"THREADS_PER_GROUP must divide warp size for the shuffle "
"mask to be valid");
+ static_assert(
+ kGroupsPerBlockX > 0 && (kGroupsPerBlockX & (kGroupsPerBlockX - 1)) == 0,
+ "kGroupsPerBlockX must be a positive power of 2");
+ static_assert(kRowsPerBlock > 0, "kRowsPerBlock must be positive");
const int local_group_id = threadIdx.x / THREADS_PER_GROUP;
const int lane_id = threadIdx.x % THREADS_PER_GROUP;
- const int64_t block_group_id = blockIdx.x * groups_per_block;
- const int64_t global_group_id = block_group_id + local_group_id;
- if (global_group_id >= num_groups_padded) {
+ const int sf_k_local = local_group_id % kGroupsPerBlockX;
+ const int row_local = local_group_id / kGroupsPerBlockX;
+ const int sf_k_idx = blockIdx.x * kGroupsPerBlockX + sf_k_local;
+ const int mn_idx = blockIdx.y * kRowsPerBlock + row_local;
+
+ if (mn_idx >= tma_aligned_mn) {
return;
}
-
- const int sf_k_idx =
- static_cast<int>(global_group_id % padded_groups_per_row);
- const int mn_idx = static_cast<int>(global_group_id / padded_groups_per_row);
const bool is_valid_group = (mn_idx < mn) && (sf_k_idx < groups_per_row);
// Load 16 input elements (32 B) into registers as two adjacent uint4
커널 런칭 로직 변경
호스트 코드(per_token_group_quant_8bit_packed 함수)에서 커널을 런칭하는 방식도 변경되었습니다. 기존에는 1D num_blocks를 계산하여 grid.x에 할당했지만, 이제는 blocks_x와 blocks_y를 계산하여 2D dim3 grid를 구성합니다. kx와 ry 값에 따라 적절한 템플릿 인수를 사용하여 커널을 인스턴스화하는 LAUNCH_REG_KERNEL_INST 매크로가 도입되었습니다.
Before:
--- a/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
+++ b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -443,34 +452,53 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
constexpr int THREADS_PER_GROUP = 8;
const int64_t padded_groups_per_row = k_num_packed_sfk * 4;
- const int64_t num_groups_padded = tma_aligned_mn * padded_groups_per_row;
const int64_t num_scale_elems = mn + (k_num_packed_sfk - 1) * tma_aligned_mn;
- const int groups_per_block = GetGroupsPerBlock(num_groups_padded);
+
+ STD_TORCH_CHECK(padded_groups_per_row % 4 == 0,
+ "padded_groups_per_row=", padded_groups_per_row,
+ " is not a multiple of 4.");
+ const int kx = GetGroupsPerBlockX(padded_groups_per_row);
+ const int ry = 16 / kx;
+ const int64_t blocks_x = padded_groups_per_row / kx;
+ const int64_t blocks_y = (tma_aligned_mn + ry - 1) / ry;
+ const int num_threads = (kx * ry) * THREADS_PER_GROUP;
+ // CUDA caps grid.x and grid.y at 2^31 - 1; guard against pathological inputs.
+ STD_TORCH_CHECK(blocks_x <= static_cast<int64_t>(INT32_MAX) &&
+ blocks_y <= static_cast<int64_t>(INT32_MAX),
+ "per_token_group_quant_8bit_packed grid too large: (",
+ blocks_x, ", ", blocks_y, ").");
auto dst_type = output_q.scalar_type();
- const int64_t num_blocks = num_groups_padded / groups_per_block;
- const int num_threads = groups_per_block * THREADS_PER_GROUP;
- // CUDA caps grid.x at 2^31 - 1; this fits any realistic shape but guard
- // against pathological inputs.
- STD_TORCH_CHECK(num_blocks <= static_cast<int64_t>(INT32_MAX),
- "per_token_group_quant_8bit_packed grid too large: ",
- num_blocks, " blocks (max ", INT32_MAX, ").");
-
-#define LAUNCH_REG_KERNEL(T, DST_DTYPE)
- do {
- dim3 grid(static_cast<unsigned int>(num_blocks));
- dim3 block(num_threads);
- per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128>
- <<<grid, block, 0, stream>>>(
- static_cast<const T*>(input.data_ptr()), output_q.data_ptr(),
- reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()),
- num_groups_padded, groups_per_block,
- static_cast<int>(padded_groups_per_row),
- static_cast<int>(groups_per_row), static_cast<int>(mn),
- static_cast<int>(output_q_mn_extent),
- static_cast<int>(tma_aligned_mn), num_scale_elems,
- static_cast<float>(eps), static_cast<float>(min_8bit),
- static_cast<float>(max_8bit));
+```
**After:**
```diff
--- a/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
+++ b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -443,34 +452,53 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input,
constexpr int THREADS_PER_GROUP = 8;
const int64_t padded_groups_per_row = k_num_packed_sfk * 4;
- const int64_t num_groups_padded = tma_aligned_mn * padded_groups_per_row;
const int64_t num_scale_elems = mn + (k_num_packed_sfk - 1) * tma_aligned_mn;
- const int groups_per_block = GetGroupsPerBlock(num_groups_padded);
+
+ STD_TORCH_CHECK(padded_groups_per_row % 4 == 0,
+ "padded_groups_per_row=", padded_groups_per_row,
+ " is not a multiple of 4.");
+ const int kx = GetGroupsPerBlockX(padded_groups_per_row);
+ const int ry = 16 / kx;
+ const int64_t blocks_x = padded_groups_per_row / kx;
+ const int64_t blocks_y = (tma_aligned_mn + ry - 1) / ry;
+ const int num_threads = (kx * ry) * THREADS_PER_GROUP;
+ // CUDA caps grid.x and grid.y at 2^31 - 1; guard against pathological inputs.
+ STD_TORCH_CHECK(blocks_x <= static_cast<int64_t>(INT32_MAX) &&
+ blocks_y <= static_cast<int64_t>(INT32_MAX),
+ "per_token_group_quant_8bit_packed grid too large: (",
+ blocks_x, ", ", blocks_y, ").");
auto dst_type = output_q.scalar_type();
- const int64_t num_blocks = num_groups_padded / groups_per_block;
- const int num_threads = groups_per_block * THREADS_PER_GROUP;
- // CUDA caps grid.x at 2^31 - 1; this fits any realistic shape but guard
- // against pathological inputs.
- STD_TORCH_CHECK(num_blocks <= static_cast<int64_t>(INT32_MAX),
- "per_token_group_quant_8bit_packed grid too large: ",
- num_blocks, " blocks (max ", INT32_MAX, ").");
-
-#define LAUNCH_REG_KERNEL(T, DST_DTYPE)
- do {
- dim3 grid(static_cast<unsigned int>(num_blocks));
- dim3 block(num_threads);
- per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128>
- <<<grid, block, 0, stream>>>(
- static_cast<const T*>(input.data_ptr()), output_q.data_ptr(),
- reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()),
- num_groups_padded, groups_per_block,
- static_cast<int>(padded_groups_per_row),
- static_cast<int>(groups_per_row), static_cast<int>(mn),
- static_cast<int>(output_q_mn_extent),
- static_cast<int>(tma_aligned_mn), num_scale_elems,
- static_cast<float>(eps), static_cast<float>(min_8bit),
- static_cast<float>(max_8bit));
+```
**After:**
```diff
--- a/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
+++ b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -475,6 +475,20 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
static_cast<float>(max_8bit));
} while (0)
+#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
+ do { \
+ if (kx == 16) { \
+ LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 16, 1); \
+ } else if (kx == 8) { \
+ LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 8, 2); \
+ } else if (kx == 4) { \
+ LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 4, 4); \
+ } else { \
+ STD_TORCH_CHECK(false, "Unsupported kx value ", kx); \
+ } \
+ } while (0)
+
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
if (dst_type == torch::kInt8) {
왜 이게 좋은 최적화인가?
이 최적화는 GPU 커널의 성능에 여러 가지 긍정적인 영향을 미칩니다.
-
divmod연산 제거: GPU에서 정수 나눗셈과 나머지 연산은 일반적으로 비트 시프트나 마스킹과 같은 비트 연산보다 훨씬 느립니다. 특히,padded_groups_per_row와 같이 컴파일 타임에 상수로 결정될 수 있는 값에 대한divmod는 컴파일러가 최적화하기 어렵습니다. 2D-grid와 템플릿 상수를 사용하면, 이러한 연산이 컴파일 시점에 상수 폴딩되거나 효율적인 비트 연산으로 대체될 수 있습니다. 이는 GPU의 ALU(Arithmetic Logic Unit) 부담을 줄여 커널 실행 속도를 향상시킵니다. -
컴파일 타임 상수 활용:
kGroupsPerBlockX와kRowsPerBlock을 템플릿 인수로 전달함으로써, 이 값들이 커널 내부에서 컴파일 타임 상수로 활용됩니다. 컴파일러는 이러한 상수를 사용하여 코드 경로를 최적화하고, 런타임에 계산될 필요가 없는 분기나 연산을 제거할 수 있습니다. 이는 더 효율적인 기계어 코드를 생성하여 성능을 높입니다. -
최적화된 스레드 블록 구성:
GetGroupsPerBlockX함수를 통해padded_groups_per_row에 따라kx값을 동적으로 선택하고, 이를 기반으로 2D-grid를 구성하는 것은 GPU의 스레드 블록(thread block) 및 워프(warp) 스케줄링을 최적화하는 데 도움이 됩니다. 이는 GPU의 병렬 처리 능력을 더 효과적으로 활용하게 합니다.
성능 수치
PR 설명에 포함된 마이크로 벤치마크 결과는 이 최적화의 효과를 명확히 보여줍니다. 특히 K와 M 값이 큰 경우, 최대 1.099배의 속도 향상을 보였습니다.
| K | M | baseline us | opt us | speedup |
|---|---|---|---|---|
| 7168 | 8192 | 36.95 | 33.62 | 1.099x |
| 7168 | 1024 | 5.68 | 5.26 | 1.080x |
| 2048 | 8192 | 10.84 | 10.05 | 1.079x |
| 1536 | 8192 | 8.55 | 7.96 | 1.075x |
| 384 | 8192 | 4.17 | 3.94 | 1.057x |
이러한 수치는 divmod 연산이 GPU 성능에 미치는 영향이 상당하며, 이를 제거하는 것이 큰 성능 개선으로 이어진다는 것을 입증합니다. 작은 M 값에서는 속도 향상이 미미하지만, 이는 전체 연산 시간이 짧아 divmod 연산의 상대적 비중이 줄어들기 때문으로 해석할 수 있습니다.
또한, jiahanc 님의 리뷰 댓글에서 제공된 DeepSeek V4 Flash TP4 모델에 대한 E2E 벤치마크 결과도 긍정적인 영향을 보여줍니다. Concurrency 1024 환경에서 Output token throughput이 14418.23 tok/s에서 14505.22 tok/s로 소폭 상승했으며, Mean TPOT(Time per Output Token)도 67.08ms에서 66.66ms로 개선되었습니다. 이는 마이크로 벤치마크만큼 극적인 수치는 아니지만, 실제 LLM 서빙 환경에서도 전반적인 처리량 개선에 기여함을 시사합니다.
일반적 교훈
이 PR은 고성능 컴퓨팅, 특히 GPU 프로그래밍에서 중요한 몇 가지 교훈을 제공합니다.
-
divmod연산의 주의: GPU 커널에서divmod연산은 가능한 한 피해야 합니다. 특히, 피연산자가 컴파일 타임에 상수로 알려진 경우, 이를 비트 연산이나 템플릿 메타프로그래밍을 통해 대체하는 방법을 적극적으로 고려해야 합니다. -
컴파일 타임 상수 활용의 중요성: 템플릿 인수를 통해 커널에 상수를 전달하면, 컴파일러가 더 많은 최적화를 수행할 수 있습니다. 이는 런타임 오버헤드를 줄이고 더 효율적인 코드를 생성하는 데 필수적입니다.
-
GPU Grid/Block 구성의 최적화: GPU 커널의 성능은 grid와 block의 구성에 크게 좌우됩니다. 데이터 접근 패턴과 연산 특성에 맞춰 1D-grid 대신 2D-grid 또는 3D-grid를 사용하고, 블록 내 스레드 배치를 신중하게 설계하는 것이 중요합니다. 이는 메모리 접근 패턴, 워프 스케줄링, 리소스 활용률에 직접적인 영향을 미칩니다.
-
마이크로 벤치마크의 가치: 특정 코드 변경이 전체 시스템에 미치는 영향을 정확히 파악하기 어렵더라도, 핵심 연산에 대한 마이크로 벤치마크는 최적화의 효과를 정량적으로 측정하고 방향성을 제시하는 데 매우 유용합니다.
이러한 원칙들은 vLLM과 같이 성능에 민감한 시스템을 개발할 때 항상 염두에 두어야 할 사항들입니다. 작은 최적화가 모여 전체 시스템의 효율성을 크게 향상시킬 수 있음을 보여주는 좋은 예시입니다.
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [vllm] Blackwell을 위한 새로운 MLA 백엔드: TOKENSPEED_MLA 분석 (DeepSeek R1 최적화)
- [vllm] vLLM의 MLA 성능 극대화: RoPE, KV Cache, q_concat 연산 퓨전 최적화
- [vllm] vLLM DeepSeek v4 Fused Indexer Q 양자화 커널 최적화: CuteDSL을 활용한 성능 향상
- [vllm] vLLM, Gemma 4 모델에 양자화된 Speculative Decoding 적용: 성능 향상의 비밀
- [flashinfer] FlashInfer, CUDA 그래프 호환성을 높이고 성능을 최적화하다: TRT-LLM FMHA v2 통합 및 불필요한 H2D 제거
PR Analysis 의 다른글
- 이전글 [sglang] NPU 성능 향상을 위한 causal_conv1d_update_v2 도입
- 현재글 : [vllm] vLLM W8W8 그룹 양자화 성능 최적화: 2D-Grid를 통한 Divmod 제거
- 다음글 [sglang] DeepseekV4 모델의 입력 레이어 정규화와 FP8 양자화를 융합하여 성능 최적화
댓글