본문으로 건너뛰기

[sglang] QKNorm Across Heads CUDA 커널 최적화: Q/K 분리로 레지스터 압력 해소

PR 링크: sgl-project/sglang#21503 상태: Merged | 변경: +32 / -81

들어가며

QKNorm은 Diffusion 모델의 attention에서 Query와 Key 텐서에 RMSNorm을 적용하는 연산입니다. 기존 CUDA 커널은 하나의 thread block이 Q와 K를 동시에 처리했는데, 이로 인해 레지스터와 shared memory 사용량이 두 배로 필요했습니다. 이번 PR은 2D grid를 활용하여 Q와 K를 별도 블록에서 처리하도록 개선합니다.

핵심 코드 분석

1. 변수 통합 및 분기 제거

Before:

vec_t v_q;         // Save q
vec_t v_k;         // Save k
vec_t v_q_weight;  // Save q_weight
vec_t v_k_weight;  // Save k_weight
vec_t v_q_out;     // Save q output
vec_t v_k_out;     // Save k output

float2 acc_square_q = make_float2(0.0f, 0.0f);
float2 acc_square_k = make_float2(0.0f, 0.0f);

After:

vec_t v_data;
vec_t v_weight;
const bool is_q = blockIdx.y == 0;

float2 acc_square = make_float2(0.0f, 0.0f);
vec_t* data = reinterpret_cast<vec_t*>(is_q ? q : k) + token_id * vec_hidden_size;
const vec_t* weight = reinterpret_cast<const vec_t*>(is_q ? q_weight : k_weight);

6개의 벡터 레지스터가 2개로 줄었습니다. blockIdx.y로 Q(0)와 K(1)를 구분하여 동일한 코드 경로로 처리합니다.

2. Shared memory 절반 감소

Before:

__shared__ float shared_memory[64];  // Q용 32 + K용 32
float* buffer_q = shared_memory;       // [0, 31]
float* buffer_k = shared_memory + 32;  // [32, 63]

After:

__shared__ float shared_memory[32];  // Q 또는 K 하나만
float* buffer = shared_memory;

3. CTA reduce 단순화

Before:

// Q용 CTA reduce
float cta_sum_q = cooperative_groups::reduce(...);
buffer_q[threadIdx.x] = rsqrtf(eps + cta_sum_q * ...);

// K용 CTA reduce (별도로 한 번 더)
float cta_sum_k = cooperative_groups::reduce(...);
buffer_k[threadIdx.x] = rsqrtf(eps + cta_sum_k * ...);

After:

float cta_sum = cooperative_groups::reduce(...);
if (threadIdx.x == 0) {
    buffer[0] = rsqrtf(eps + cta_sum * inv_hidden_size);
}

reduce 결과를 warp 0의 모든 lane에 broadcast하는 대신, thread 0만 쓰고 __syncthreads() 후 모든 thread가 buffer[0]을 읽는 방식으로 변경했습니다.

4. Grid 차원 변경

// Before: 1D grid (token 수만큼)
LaunchKernel(static_cast<uint>(N.unwrap()), threads, device.unwrap())

// After: 2D grid (token x 2, Q와 K 분리)
LaunchKernel(dim3(static_cast<uint>(N.unwrap()), 2), threads, device.unwrap())

왜 이게 좋은가

  1. 레지스터 압력 감소: 레지스터 사용량이 절반으로 줄어 GPU SM당 더 많은 warp를 동시 실행할 수 있습니다(occupancy 향상).
  2. Shared memory 효율: 32 float만 사용하여 다른 커널과의 shared memory 경합이 줄어듭니다.
  3. 코드 단순화: 81줄 삭제, 32줄 추가로 코드가 더 읽기 쉬워졌습니다.

정리

2D grid 활용으로 Q/K 처리를 분리하는 고전적이지만 효과적인 최적화입니다. 하나의 커널이 너무 많은 일을 하면 레지스터 압력이 증가하므로, 작업을 분할하여 각 블록의 리소스 요구량을 줄이는 접근입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글