본문으로 건너뛰기

[sglang] Diffusion 모델용 Fused QKNorm+RoPE CUDA 커널 추가

PR 링크: sgl-project/sglang#21440 상태: Merged | 변경: +961 / -78

들어가며

Diffusion 모델의 attention 연산에서는 QKNorm(RMSNorm on Q/K)과 RoPE(Rotary Position Embedding)가 순차적으로 적용됩니다. 별도 커널로 실행하면 Q/K 텐서를 두 번 읽고 써야 하므로 메모리 대역폭이 병목이 됩니다. 이번 PR은 두 연산을 하나의 warp-level CUDA 커널로 융합하여 메모리 접근을 절반으로 줄입니다.

핵심 코드 분석

1. 커널 구조: Warp-level 처리

template <int64_t kHeadDim, int64_t kRopeDim, bool kIsNeox, bool kUsePDL,
          typename DType, typename IdType>
__global__ void fused_qknorm_rope_warp(const QKNormRopeParams params) {
    constexpr uint32_t kElemsPerThread = kHeadDim / kWarpThreads;
    constexpr uint32_t kRotaryLanes = kRopeDim / kElemsPerThread;

    // 각 warp가 하나의 (token, head) 조합을 처리
    const uint32_t num_qk_heads = num_qo_heads + num_kv_heads;
    for (uint32_t idx = start_worker_id; idx < num_works; idx += num_workers) {
        const uint32_t token_id = idx / num_qk_heads;
        const uint32_t head_id = idx % num_qk_heads;

        // 1. 데이터 로드
        auto input_vec = load_as<Storage>(input, lane_id);

        // 2. RMSNorm: warp reduce로 sum of squares 계산
        // 3. Normalize
        // 4. RoPE 적용 (같은 레지스터에서 in-place)
        // 5. 결과 쓰기
    }
}

한 warp(32 threads)가 하나의 head를 처리합니다. kHeadDim=128일 때 thread당 4개 원소를 담당하며, RMSNorm의 reduction과 RoPE의 회전을 레지스터에서 수행합니다.

2. 분리 실행 vs 융합 커널 비교

Before (분리 실행):

# 1단계: QKNorm (Q, K 각각 읽기/쓰기)
fused_inplace_qknorm(q, k, q_weight, k_weight)
# 2단계: RoPE (Q, K 다시 읽기/쓰기)
apply_rope_with_cos_sin_cache_inplace(positions, q, k, ...)

After (융합 커널):

# 1단계: QKNorm + RoPE (Q, K 한 번만 읽기/쓰기)
fused_inplace_qknorm_rope(q, k, q_weight, k_weight,
    cos_sin_cache, positions, is_neox=False, rope_dim=128)

3. 벤치마크 인프라

BENCH_CASES = (
    CaseSpec("flux_1024", 1, 4096, 24, 128, 128, False),
    CaseSpec("qwen_image_1024", 1, 4096, 32, 128, 128, False),
    CaseSpec("zimage_1024", 1, 4096, 30, 128, 128, False),
)

Flux, Qwen Image, Z-Image-Turbo 등 실제 Diffusion 모델 설정에 대한 벤치마크가 포함되어 있습니다.

왜 이게 좋은가

  1. 메모리 대역폭 절약: Q/K 텐서를 한 번만 읽고 쓰므로, 메모리 바운드 연산에서 ~2x 속도 향상이 가능합니다.
  2. 레지스터 재사용: RMSNorm 결과를 레지스터에 유지한 채 RoPE를 적용하여 중간 메모리 할당이 불필요합니다.
  3. PDL 지원: PDLWaitPrimary를 통해 Programmatic Dependent Launch가 가능하여, 커널 체인의 오버헤드를 줄입니다.

정리

Kernel fusion은 GPU 프로그래밍의 핵심 최적화 기법입니다. 두 개의 element-wise/reduction 연산을 하나로 합치면 메모리 접근 횟수가 반으로 줄어, 특히 대형 시퀀스 길이에서 큰 성능 향상을 얻습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글