[sglang] QKNorm Across Heads CUDA 커널 최적화: Q/K 분리로 레지스터 압력 해소
PR 링크: sgl-project/sglang#21503 상태: Merged | 변경: +32 / -81
들어가며
QKNorm은 Diffusion 모델의 attention에서 Query와 Key 텐서에 RMSNorm을 적용하는 연산입니다. 기존 CUDA 커널은 하나의 thread block이 Q와 K를 동시에 처리했는데, 이로 인해 레지스터와 shared memory 사용량이 두 배로 필요했습니다. 이번 PR은 2D grid를 활용하여 Q와 K를 별도 블록에서 처리하도록 개선합니다.
핵심 코드 분석
1. 변수 통합 및 분기 제거
Before:
vec_t v_q; // Save q
vec_t v_k; // Save k
vec_t v_q_weight; // Save q_weight
vec_t v_k_weight; // Save k_weight
vec_t v_q_out; // Save q output
vec_t v_k_out; // Save k output
float2 acc_square_q = make_float2(0.0f, 0.0f);
float2 acc_square_k = make_float2(0.0f, 0.0f);
After:
vec_t v_data;
vec_t v_weight;
const bool is_q = blockIdx.y == 0;
float2 acc_square = make_float2(0.0f, 0.0f);
vec_t* data = reinterpret_cast<vec_t*>(is_q ? q : k) + token_id * vec_hidden_size;
const vec_t* weight = reinterpret_cast<const vec_t*>(is_q ? q_weight : k_weight);
6개의 벡터 레지스터가 2개로 줄었습니다. blockIdx.y로 Q(0)와 K(1)를 구분하여 동일한 코드 경로로 처리합니다.
2. Shared memory 절반 감소
Before:
__shared__ float shared_memory[64]; // Q용 32 + K용 32
float* buffer_q = shared_memory; // [0, 31]
float* buffer_k = shared_memory + 32; // [32, 63]
After:
__shared__ float shared_memory[32]; // Q 또는 K 하나만
float* buffer = shared_memory;
3. CTA reduce 단순화
Before:
// Q용 CTA reduce
float cta_sum_q = cooperative_groups::reduce(...);
buffer_q[threadIdx.x] = rsqrtf(eps + cta_sum_q * ...);
// K용 CTA reduce (별도로 한 번 더)
float cta_sum_k = cooperative_groups::reduce(...);
buffer_k[threadIdx.x] = rsqrtf(eps + cta_sum_k * ...);
After:
float cta_sum = cooperative_groups::reduce(...);
if (threadIdx.x == 0) {
buffer[0] = rsqrtf(eps + cta_sum * inv_hidden_size);
}
reduce 결과를 warp 0의 모든 lane에 broadcast하는 대신, thread 0만 쓰고 __syncthreads() 후 모든 thread가 buffer[0]을 읽는 방식으로 변경했습니다.
4. Grid 차원 변경
// Before: 1D grid (token 수만큼)
LaunchKernel(static_cast<uint>(N.unwrap()), threads, device.unwrap())
// After: 2D grid (token x 2, Q와 K 분리)
LaunchKernel(dim3(static_cast<uint>(N.unwrap()), 2), threads, device.unwrap())
왜 이게 좋은가
- 레지스터 압력 감소: 레지스터 사용량이 절반으로 줄어 GPU SM당 더 많은 warp를 동시 실행할 수 있습니다(occupancy 향상).
- Shared memory 효율: 32 float만 사용하여 다른 커널과의 shared memory 경합이 줄어듭니다.
- 코드 단순화: 81줄 삭제, 32줄 추가로 코드가 더 읽기 쉬워졌습니다.
정리
2D grid 활용으로 Q/K 처리를 분리하는 고전적이지만 효과적인 최적화입니다. 하나의 커널이 너무 많은 일을 하면 레지스터 압력이 증가하므로, 작업을 분할하여 각 블록의 리소스 요구량을 줄이는 접근입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [llm-compressor] iMatrix Weighted MSE Observer - 중요도 행렬 기반 양자화
- 현재글 : [sglang] QKNorm Across Heads CUDA 커널 최적화: Q/K 분리로 레지스터 압력 해소
- 다음글 [sglang] Diffusion 모델용 Fused QKNorm+RoPE CUDA 커널 추가
댓글