본문으로 건너뛰기

[sglang] JIT RMSNorm 커널 업데이트 - Blackwell 최적화 및 벤치마크 통합

PR 링크: sgl-project/sglang#21834 상태: Merged | 변경: +322 / -394

들어가며

SGLang의 JIT RMSNorm 커널은 hidden dimension 크기에 따라 단일 CTA(Cooperative Thread Array)로 처리하는 고성능 커널이다. 기존 구현은 hidden_size <= 8192만 지원했고, 벤치마크가 3개 파일로 분산되어 있었다. 이 PR은 16384까지 지원하는 새로운 커널 변형을 추가하고, 벤치마크를 하나로 통합하며, L2 캐시 효과를 방지하는 multi-layer 벤치마크 기법을 도입한다.

핵심 코드 분석

1. Pre-Blackwell: Double Load 커널 (hidden_size > 8192)

Before:

// hidden_size > 8192는 지원하지 않음
// 기존 rmsnorm_cta는 스레드당 1회 load/store

After:

template <int64_t kDim, bool kUsePDL, typename Float>
__global__ __launch_bounds__(kDim / 16) void rmsnorm_cta_double(
    const RMSNormParams __grid_constant__ params) {
  using Storage = AlignedVector<Float2, 4>;  // 16B vector

  // 스레드당 2회 load → sum_of_squares 누적
  const auto input_first = gmem.load(input_ptr, 0);
  const auto input_second = gmem.load(input_ptr, 1);

  // 두 벡터에 대해 각각 norm 적용 후 store
  gmem.store(output_ptr, output_first, 0);
  gmem.store(output_ptr, output_second, 1);
}

각 스레드가 16B 벡터를 2번 로드하여 hidden_size 16384까지 처리 가능. Warp shuffle + shared memory로 reduce sum을 수행한다.

2. Blackwell: Wide Vector 커널

After:

template <int64_t kDim, bool kUsePDL, typename Float>
__global__ __launch_bounds__(kDim / 16) void rmsnorm_cta_wide(
    const RMSNormParams __grid_constant__ params) {
  using Storage = AlignedVector<Float2, 8>;  // 32B vector

  // 스레드당 1회 load이지만 32B로 2배 데이터 처리
  const auto input_vec = gmem.load(input_ptr);
  // 8개 Float2 쌍에 대해 norm 적용
  gmem.store(output_ptr, output_vec);
}

Blackwell GPU의 32B 메모리 접근 폭을 활용하여 단일 로드로 16개 요소를 처리한다.

3. 벤치마크 통합 및 L2 캐시 효과 방지

Before:

# bench_rmsnorm.py, bench_fused_add_rmsnorm.py, bench_norm.py 3개 파일
input = torch.randn((batch_size, hidden_size), ...)
fn = lambda: jit_rmsnorm(input.clone(), weight)

After:

# bench_norm.py 하나로 통합
NUM_LAYERS = 4  # L2 캐시 효과 방지
input = torch.randn((NUM_LAYERS, batch_size, hidden_size), ...)
def f():
    for i in range(NUM_LAYERS):
        fn(input[i], weight[i], out=input[i])
return run_benchmark(f, scale=NUM_LAYERS)

4개 layer를 순회하며 벤치마크하고 scale 파라미터로 나누어 per-layer 시간을 측정한다. 이 방식은 실제 모델 추론 시의 L2 캐시 상태를 더 현실적으로 반영한다.

왜 이게 좋은가

  • 지원 범위 확장: hidden_size 16384까지 지원하여 대형 모델(예: Llama 405B의 hidden_size=16384) 커버
  • 아키텍처별 최적화: Pre-Blackwell은 double load, Blackwell은 wide vector로 각각 최적 경로 제공
  • 벤치마크 신뢰도 향상: L2 캐시 오염을 방지하는 multi-layer 벤치마크로 실제 성능에 가까운 측정
  • 코드 정리: 3개로 분산된 벤치마크 파일을 1개로 통합 (순감소 -72줄)

정리

이 PR은 JIT RMSNorm 커널의 커버리지와 성능을 모두 개선한다. hidden_size 제약을 2배로 늘리면서 Blackwell GPU에서의 메모리 대역폭 활용을 극대화하고, 벤치마크 방법론도 더 현실적인 방향으로 개선했다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글