본문으로 건너뛰기

[sglang] SGLang LTX-2.3 Diffusion 모델 최적화: Residual-Gate 연산 CUDA Fast Path 도입

PR 링크: sgl-project/sglang#29361 상태: Merged | 변경: +672 / -10

들어가며

SGLang은 대규모 언어 모델(LLM) 및 멀티모달 모델의 추론 성능을 최적화하는 데 중점을 둔 프레임워크입니다. 특히, LTX-2.3과 같은 Diffusion 모델은 고해상도 이미지 및 비디오 생성을 위해 복잡하고 반복적인 연산을 수행하며, 이 과정에서 성능 병목이 발생하기 쉽습니다. 이번 PR은 LTX-2.3 모델의 핵심 연산 중 하나인 residual-gate update 패턴에 대한 성능 최적화를 목표로 합니다.

residual-gate updatehidden_states = hidden_states + update * gate 형태의 연산으로, Diffusion Transformer (DiT) 블록 내의 residual, attention, MLP 업데이트에서 빈번하게 사용됩니다. 기존에는 PyTorch eager 모드나 Triton 커널을 통해 처리되었으나, LTX-2.3 HQ 모델에서 발생하는 [1, 32640, 4096]과 같은 대규모 broadcast gate 또는 [1, 8160, 4096]과 같은 full gate 텐서 처리 시 성능 저하가 관찰되었습니다. 이 PR은 이러한 "핫 패턴"에 대해 네이티브 CUDA JIT 커널을 도입하여 성능 병목을 해결합니다.

코드 분석: 무엇이 어떻게 개선되었나

이 PR의 핵심은 diffusion_residual_gate_add라는 새로운 경량 CUDA JIT 커널을 추가하여 residual + update * gate 연산을 효율적으로 수행하는 것입니다. 이 커널은 특히 대규모 broadcast gate 패턴에 최적화된 row-tiled 방식을 사용합니다.

1. python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh 파일 추가

이 PR은 residual_gate_add.cuh라는 새로운 CUDA 커널 파일을 추가합니다. 기존에는 이 연산이 Python 또는 Triton 커널을 통해 처리되었으나, 이제는 특정 워크로드에 최적화된 CUDA 커널이 직접 사용됩니다.

Before (개념적인 Python/Triton 구현):

hidden_states = hidden_states + update * gate

After (새로운 CUDA 커널 파일 도입):

diff --git a/python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh b/python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh
new file mode 100644
index 000000000000..693641880298
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh
@@ -0,0 +1,322 @@
+// CUDA fast path for diffusion residual-gate elementwise updates.
+//
+// Implements:
+//   out = residual + update * gate
+//
+// The production shapes come from LTX-2.3 HQ residual/gate updates.  This is
+// intentionally narrow: contiguous residual/update/out tensors, with either a
+// full contiguous gate or a row-broadcast [1, 1, D] gate.
+//
+// Developed with MIT HAN Lab Kernel Design Agents:
+// https://github.com/mit-han-lab/kernel-design-agents
...

이 파일은 residual_gate_add_vec_kernel (full gate 처리)과 residual_gate_add_bcast_row_tile_kernel (row-broadcast gate 처리) 두 가지 주요 커널을 포함합니다. 특히 residual_gate_add_bcast_row_tile_kernel은 대규모 broadcast gate 처리에 최적화되어 있습니다.

핵심 연산 residual_gate_value:

template <typename T>
__device__ __forceinline__ T residual_gate_value(T residual, T update, T gate) {
  const T product = dtype_trait<T>::from(to_float(update) * to_float(gate));
  return dtype_trait<T>::from(to_float(residual) + to_float(product));
}

이 템플릿 함수는 fp16_t, bf16_t 등 다양한 데이터 타입에 대해 float으로 변환하여 연산을 수행하고 다시 원래 타입으로 변환하는 방식으로 정밀도를 유지하면서 유연성을 제공합니다.

2. residual_gate_add_bcast_row_tile_kernel의 최적화

대규모 broadcast gate의 경우, gate 텐서는 [B, 1, D] 형태를 가지며, 각 행에 대해 동일한 gate 벡터가 적용됩니다. 기존 방식은 각 행마다 gate 값을 다시 로드하거나 복잡한 인덱싱을 사용했을 수 있습니다. 이 커널은 row-tiled 방식을 사용하여 gate 메모리 접근을 최적화합니다.

template <typename T, int kVec>
__global__ void residual_gate_add_bcast_row_tile_kernel(
    const T* __restrict__ residual,
    const T* __restrict__ update,
    const T* __restrict__ gate,
    T* __restrict__ out,
    int64_t rows,
    int64_t row_vec) {
  const int64_t col_vec = static_cast<int64_t>(blockIdx.x) * kBcastColsVecPerBlock + threadIdx.x;
  if (col_vec >= row_vec) {
    return;
  }

  // Broadcast gate vector를 한 번 로드하여 여러 행에 재사용
  const Vec16<T> g{.raw = SGLANG_LDG(reinterpret_cast<const uint4*>(gate) + col_vec)};

  // Grid-stride over row tiles
  const int64_t row_tile_stride = static_cast<int64_t>(gridDim.y) * kBcastRowsPerBlock;
  for (int64_t row_base = static_cast<int64_t>(blockIdx.y) * kBcastRowsPerBlock; row_base < rows;
       row_base += row_tile_stride) {
#pragma unroll
    for (int row_offset = 0; row_offset < kBcastRowsPerBlock; ++row_offset) {
      const int64_t row = row_base + row_offset;
      if (row < rows) {
        const int64_t v = row * row_vec + col_vec;
        const Vec16<T> r{.raw = reinterpret_cast<const uint4*>(residual)[v]};
        const Vec16<T> u{.raw = reinterpret_cast<const uint4*>(update)[v]};

        Vec16<T> o;
#pragma unroll
        for (int i = 0; i < kVec; ++i) {
          o.elems[i] = residual_gate_value(r.elems[i], u.elems[i], g.elems[i]);
        }
        reinterpret_cast<uint4*>(out)[v] = o.raw;
      }
    }
  }
}
  • kBcastRowsPerBlock (4) 및 row_tile_stride: 이 상수는 한 블록이 처리할 행 타일의 수를 정의합니다. gate 벡터를 한 번 로드한 후 kBcastRowsPerBlock만큼의 행에 걸쳐 재사용함으로써, 반복적인 gate 메모리 트래픽을 줄입니다.
  • SGLANG_LDG: gate 값을 로드할 때 SGLANG_LDG (likely __ldg for global memory loads)를 사용하여 캐시 효율성을 높입니다.
  • Vec16<T>aligned16: 16바이트 단위로 메모리 접근을 벡터화하고 정렬함으로써 메모리 대역폭 활용을 극대화합니다. D % kVec == 0 조건은 벡터화된 접근이 가능한지 확인합니다.

3. launch_residual_gate_add 및 입력 검증

launch_residual_gate_add 함수는 입력 텐서의 속성(정렬, 차원, 연속성 등)을 확인하여 벡터화된 커널(residual_gate_add_vec_kernel 또는 residual_gate_add_bcast_row_tile_kernel)을 사용할지, 아니면 일반 스칼라 커널(residual_gate_add_scalar_kernel)로 폴백할지 결정합니다. 이는 최적화된 경로를 최대한 활용하면서도 다양한 입력에 대한 견고성을 보장합니다.

// launch_residual_gate_add 함수 내 일부
  const bool vec_ok = aligned16(residual_ptr) && aligned16(update_ptr) && aligned16(gate_ptr) && aligned16(out_ptr) &&
                      (D % kVec == 0) && (mode == GateMode::kBcastRow || total % kVec == 0);

  int64_t done = 0;
  if (vec_ok) {
    // ... 벡터화된 커널 실행 로직 ...
    done = n_vec * kVec;
  }

  if (done < total) {
    // ... 스칼라 커널 실행 로직 (나머지 처리) ...
  }

validate_residual_gate_add 함수는 모든 입력 텐서가 CUDA 디바이스에 있고, 동일한 디바이스 ID를 가지며, 차원과 데이터 타입이 일치하고, 메모리가 연속적인지 등을 엄격하게 검증합니다. 이는 커널의 안정적인 동작을 위한 필수적인 안전장치입니다.

왜 이 최적화가 좋은가

이 PR의 최적화는 Diffusion 모델의 핵심 연산에서 상당한 성능 향상을 가져왔으며, 이는 다음과 같은 이유로 좋은 최적화 사례입니다.

1. 벤치마크를 통한 명확한 성능 향상

B200 Kernel Benchmark:

Workload Gate Torch us Triton us CUDA us CUDA / Triton
ltx2_bcast_s32640_c4096 bcast 415.24 132.98 120.04 1.108x
ltx2_full_s8160_c4096 full 65.46 47.07 41.65 1.130x
ideogram4_bcast_s4096_c4608 bcast 64.71 34.94 19.92 1.754x
flux2_bcast_s4608_c3072 bcast 47.46 34.49 15.69 2.198x
flux2_bcast_s4096_c3072 bcast 41.41 34.53 15.56 2.219x
flux2_bcast_s512_c3072 bcast 13.41 34.45 15.51 2.221x
ltx2_full_s126_c2048 full 13.17 35.02 13.54 2.587x

H100 Kernel Benchmark:

Workload Gate Torch us Triton us CUDA us CUDA / Triton
ltx2_bcast_s32640_c4096 bcast 598.64 295.20 264.68 1.115x
ltx2_full_s8160_c4096 full 136.22 91.67 88.52 1.036x
ideogram4_bcast_s4096_c4608 bcast 91.05 43.35 41.30 1.050x
flux2_bcast_s4608_c3072 bcast 69.64 32.81 31.45 1.043x
flux2_bcast_s4096_c3072 bcast 61.24 29.83 27.50 1.085x
flux2_bcast_s512_c3072 bcast 9.21 22.33 10.26 2.176x
ltx2_full_s126_c2048 full 8.41 20.50 8.43 2.431x

벤치마크 결과에서 볼 수 있듯이, 특히 broadcast gate 패턴에서 Triton 대비 최대 2.2배 (B200) / 2.4배 (H100) 이상의 속도 향상을 달성했습니다. ltx2_bcast_s32640_c4096와 같은 LTX-2.3의 주요 워크로드에서도 10% 이상의 성능 향상을 보였습니다. 이는 Diffusion 모델의 추론 시간을 단축하는 데 직접적으로 기여합니다.

2. Nsight Compute 분석을 통한 최적화 효과 입증

대규모 broadcast workload (ltx2_bcast_s32640_c4096)에 대한 Nsight Compute 분석 결과는 row-tiled CUDA path의 효율성을 명확히 보여줍니다.

Implementation NCU duration us DRAM % SM % Global load inst L1/TEX load sectors Regs/thread Achieved occupancy
Original flat CUDA candidate 204.99 56.79 70.54 1,566,720 25,067,520 44 51.08%
Triton BLC fallback 157.60 73.34 16.07 1,566,720 20,889,600 112 22.79%
CUDA row-tile 4x 132.54 87.06 42.66 1,175,040 18,800,640 30 89.93%
  • NCU duration us 감소: 204.99us에서 132.54us로 크게 감소하여 실제 실행 시간이 단축되었음을 보여줍니다.
  • DRAM % 증가 (87.06%): DRAM 대역폭 활용률이 높아져 GPU의 메모리 시스템을 더욱 효율적으로 사용하고 있음을 나타냅니다.
  • Global load instL1/TEX load sectors 감소: 전역 메모리 로드 명령어와 L1/TEX 캐시 로드 섹터가 감소했습니다. 이는 row-tiled 방식이 gate 벡터를 재사용하여 불필요한 메모리 접근을 줄였기 때문입니다.
  • Regs/thread 감소 (30): 스레드당 레지스터 사용량이 줄어들어 더 많은 스레드가 동시에 실행될 수 있는 여지를 확보했습니다.
  • Achieved occupancy 대폭 증가 (89.93%): GPU의 Streaming Multiprocessor (SM) 점유율이 크게 높아져 하드웨어 리소스 활용률이 극대화되었습니다.

이러한 지표들은 "row-tiled CUDA path가 각 broadcast gate 벡터를 4개 행에 걸쳐 재사용하고, 핫 루프에서 비용이 많이 드는 flat-index modulo 연산을 제거하며, 반복적인 gate 메모리 트래픽을 줄인다"는 PR 설명의 핵심 최적화 전략이 성공적으로 작동했음을 증명합니다.

3. End-to-End 모델 성능 향상

B200 Model E2E Benchmark:

Implementation E2E ms LTX2AVDenoisingStage ms LTX2RefinementStage ms Peak reserved MB Speedup
Main torch expression 46644.08 29839.68 12758.12 67902 1.00x
CUDA JIT fast path 45198.37 28816.24 12419.63 67642 1.032x (+3.20%)

전체 LTX-2.3 HQ 모델 벤치마크에서도 3.2%의 유의미한 속도 향상을 달성했습니다. 이는 특정 "핫 패턴"에 대한 커널 최적화가 전체 모델 성능에 긍정적인 영향을 미칠 수 있음을 보여줍니다.

4. 일반적 교훈

  • 워크로드별 맞춤형 커널의 중요성: 범용적인 연산은 편리하지만, 특정 하드웨어 및 워크로드의 "핫 패턴"에서는 맞춤형 CUDA 커널이 압도적인 성능 우위를 제공할 수 있습니다.
  • 메모리 접근 패턴 최적화: row-tiled 방식과 같은 메모리 접근 패턴 최적화는 GPU의 메모리 대역폭 활용률을 높이고 캐시 미스를 줄여 성능을 크게 향상시킵니다.
  • JIT 컴파일러 호환성: torch.compile과 같은 JIT 컴파일러와의 호환성을 위해 커스텀 Op를 등록하고 fake impl을 제공하는 전략은 프레임워크의 유연성을 유지하면서도 최적화된 커널을 활용할 수 있게 합니다.
  • 견고한 입력 검증: 커널의 안정성을 위해 입력 텐서의 속성을 엄격하게 검증하는 것은 필수적입니다.

결론

이 PR은 SGLang의 LTX-2.3 Diffusion 모델에서 residual-gate update 연산의 성능 병목을 성공적으로 해결했습니다. 네이티브 CUDA JIT 커널과 row-tiled 메모리 접근 전략을 통해 대규모 broadcast gate 패턴에서 획기적인 속도 향상을 달성했으며, 이는 전체 모델의 추론 시간 단축으로 이어졌습니다. 이러한 최적화는 딥러닝 모델의 실제 배포 및 활용에 있어 매우 중요한 기여를 하며, SGLang이 고성능 멀티모달 모델을 지원하는 데 있어 핵심적인 역할을 수행하고 있음을 보여줍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글