[sglang] SGLang LTX-2.3 Diffusion 모델 최적화: Residual-Gate 연산 CUDA Fast Path 도입
PR 링크: sgl-project/sglang#29361 상태: Merged | 변경: +672 / -10
들어가며
SGLang은 대규모 언어 모델(LLM) 및 멀티모달 모델의 추론 성능을 최적화하는 데 중점을 둔 프레임워크입니다. 특히, LTX-2.3과 같은 Diffusion 모델은 고해상도 이미지 및 비디오 생성을 위해 복잡하고 반복적인 연산을 수행하며, 이 과정에서 성능 병목이 발생하기 쉽습니다. 이번 PR은 LTX-2.3 모델의 핵심 연산 중 하나인 residual-gate update 패턴에 대한 성능 최적화를 목표로 합니다.
residual-gate update는 hidden_states = hidden_states + update * gate 형태의 연산으로, Diffusion Transformer (DiT) 블록 내의 residual, attention, MLP 업데이트에서 빈번하게 사용됩니다. 기존에는 PyTorch eager 모드나 Triton 커널을 통해 처리되었으나, LTX-2.3 HQ 모델에서 발생하는 [1, 32640, 4096]과 같은 대규모 broadcast gate 또는 [1, 8160, 4096]과 같은 full gate 텐서 처리 시 성능 저하가 관찰되었습니다. 이 PR은 이러한 "핫 패턴"에 대해 네이티브 CUDA JIT 커널을 도입하여 성능 병목을 해결합니다.
코드 분석: 무엇이 어떻게 개선되었나
이 PR의 핵심은 diffusion_residual_gate_add라는 새로운 경량 CUDA JIT 커널을 추가하여 residual + update * gate 연산을 효율적으로 수행하는 것입니다. 이 커널은 특히 대규모 broadcast gate 패턴에 최적화된 row-tiled 방식을 사용합니다.
1. python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh 파일 추가
이 PR은 residual_gate_add.cuh라는 새로운 CUDA 커널 파일을 추가합니다. 기존에는 이 연산이 Python 또는 Triton 커널을 통해 처리되었으나, 이제는 특정 워크로드에 최적화된 CUDA 커널이 직접 사용됩니다.
Before (개념적인 Python/Triton 구현):
hidden_states = hidden_states + update * gate
After (새로운 CUDA 커널 파일 도입):
diff --git a/python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh b/python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh
new file mode 100644
index 000000000000..693641880298
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/diffusion/residual_gate_add.cuh
@@ -0,0 +1,322 @@
+// CUDA fast path for diffusion residual-gate elementwise updates.
+//
+// Implements:
+// out = residual + update * gate
+//
+// The production shapes come from LTX-2.3 HQ residual/gate updates. This is
+// intentionally narrow: contiguous residual/update/out tensors, with either a
+// full contiguous gate or a row-broadcast [1, 1, D] gate.
+//
+// Developed with MIT HAN Lab Kernel Design Agents:
+// https://github.com/mit-han-lab/kernel-design-agents
...
이 파일은 residual_gate_add_vec_kernel (full gate 처리)과 residual_gate_add_bcast_row_tile_kernel (row-broadcast gate 처리) 두 가지 주요 커널을 포함합니다. 특히 residual_gate_add_bcast_row_tile_kernel은 대규모 broadcast gate 처리에 최적화되어 있습니다.
핵심 연산 residual_gate_value:
template <typename T>
__device__ __forceinline__ T residual_gate_value(T residual, T update, T gate) {
const T product = dtype_trait<T>::from(to_float(update) * to_float(gate));
return dtype_trait<T>::from(to_float(residual) + to_float(product));
}
이 템플릿 함수는 fp16_t, bf16_t 등 다양한 데이터 타입에 대해 float으로 변환하여 연산을 수행하고 다시 원래 타입으로 변환하는 방식으로 정밀도를 유지하면서 유연성을 제공합니다.
2. residual_gate_add_bcast_row_tile_kernel의 최적화
대규모 broadcast gate의 경우, gate 텐서는 [B, 1, D] 형태를 가지며, 각 행에 대해 동일한 gate 벡터가 적용됩니다. 기존 방식은 각 행마다 gate 값을 다시 로드하거나 복잡한 인덱싱을 사용했을 수 있습니다. 이 커널은 row-tiled 방식을 사용하여 gate 메모리 접근을 최적화합니다.
template <typename T, int kVec>
__global__ void residual_gate_add_bcast_row_tile_kernel(
const T* __restrict__ residual,
const T* __restrict__ update,
const T* __restrict__ gate,
T* __restrict__ out,
int64_t rows,
int64_t row_vec) {
const int64_t col_vec = static_cast<int64_t>(blockIdx.x) * kBcastColsVecPerBlock + threadIdx.x;
if (col_vec >= row_vec) {
return;
}
// Broadcast gate vector를 한 번 로드하여 여러 행에 재사용
const Vec16<T> g{.raw = SGLANG_LDG(reinterpret_cast<const uint4*>(gate) + col_vec)};
// Grid-stride over row tiles
const int64_t row_tile_stride = static_cast<int64_t>(gridDim.y) * kBcastRowsPerBlock;
for (int64_t row_base = static_cast<int64_t>(blockIdx.y) * kBcastRowsPerBlock; row_base < rows;
row_base += row_tile_stride) {
#pragma unroll
for (int row_offset = 0; row_offset < kBcastRowsPerBlock; ++row_offset) {
const int64_t row = row_base + row_offset;
if (row < rows) {
const int64_t v = row * row_vec + col_vec;
const Vec16<T> r{.raw = reinterpret_cast<const uint4*>(residual)[v]};
const Vec16<T> u{.raw = reinterpret_cast<const uint4*>(update)[v]};
Vec16<T> o;
#pragma unroll
for (int i = 0; i < kVec; ++i) {
o.elems[i] = residual_gate_value(r.elems[i], u.elems[i], g.elems[i]);
}
reinterpret_cast<uint4*>(out)[v] = o.raw;
}
}
}
}
kBcastRowsPerBlock(4) 및row_tile_stride: 이 상수는 한 블록이 처리할 행 타일의 수를 정의합니다.gate벡터를 한 번 로드한 후kBcastRowsPerBlock만큼의 행에 걸쳐 재사용함으로써, 반복적인gate메모리 트래픽을 줄입니다.SGLANG_LDG:gate값을 로드할 때SGLANG_LDG(likely__ldgfor global memory loads)를 사용하여 캐시 효율성을 높입니다.Vec16<T>및aligned16: 16바이트 단위로 메모리 접근을 벡터화하고 정렬함으로써 메모리 대역폭 활용을 극대화합니다.D % kVec == 0조건은 벡터화된 접근이 가능한지 확인합니다.
3. launch_residual_gate_add 및 입력 검증
launch_residual_gate_add 함수는 입력 텐서의 속성(정렬, 차원, 연속성 등)을 확인하여 벡터화된 커널(residual_gate_add_vec_kernel 또는 residual_gate_add_bcast_row_tile_kernel)을 사용할지, 아니면 일반 스칼라 커널(residual_gate_add_scalar_kernel)로 폴백할지 결정합니다. 이는 최적화된 경로를 최대한 활용하면서도 다양한 입력에 대한 견고성을 보장합니다.
// launch_residual_gate_add 함수 내 일부
const bool vec_ok = aligned16(residual_ptr) && aligned16(update_ptr) && aligned16(gate_ptr) && aligned16(out_ptr) &&
(D % kVec == 0) && (mode == GateMode::kBcastRow || total % kVec == 0);
int64_t done = 0;
if (vec_ok) {
// ... 벡터화된 커널 실행 로직 ...
done = n_vec * kVec;
}
if (done < total) {
// ... 스칼라 커널 실행 로직 (나머지 처리) ...
}
validate_residual_gate_add 함수는 모든 입력 텐서가 CUDA 디바이스에 있고, 동일한 디바이스 ID를 가지며, 차원과 데이터 타입이 일치하고, 메모리가 연속적인지 등을 엄격하게 검증합니다. 이는 커널의 안정적인 동작을 위한 필수적인 안전장치입니다.
왜 이 최적화가 좋은가
이 PR의 최적화는 Diffusion 모델의 핵심 연산에서 상당한 성능 향상을 가져왔으며, 이는 다음과 같은 이유로 좋은 최적화 사례입니다.
1. 벤치마크를 통한 명확한 성능 향상
B200 Kernel Benchmark:
| Workload | Gate | Torch us | Triton us | CUDA us | CUDA / Triton |
|---|---|---|---|---|---|
| ltx2_bcast_s32640_c4096 | bcast | 415.24 | 132.98 | 120.04 | 1.108x |
| ltx2_full_s8160_c4096 | full | 65.46 | 47.07 | 41.65 | 1.130x |
| ideogram4_bcast_s4096_c4608 | bcast | 64.71 | 34.94 | 19.92 | 1.754x |
| flux2_bcast_s4608_c3072 | bcast | 47.46 | 34.49 | 15.69 | 2.198x |
| flux2_bcast_s4096_c3072 | bcast | 41.41 | 34.53 | 15.56 | 2.219x |
| flux2_bcast_s512_c3072 | bcast | 13.41 | 34.45 | 15.51 | 2.221x |
| ltx2_full_s126_c2048 | full | 13.17 | 35.02 | 13.54 | 2.587x |
H100 Kernel Benchmark:
| Workload | Gate | Torch us | Triton us | CUDA us | CUDA / Triton |
|---|---|---|---|---|---|
| ltx2_bcast_s32640_c4096 | bcast | 598.64 | 295.20 | 264.68 | 1.115x |
| ltx2_full_s8160_c4096 | full | 136.22 | 91.67 | 88.52 | 1.036x |
| ideogram4_bcast_s4096_c4608 | bcast | 91.05 | 43.35 | 41.30 | 1.050x |
| flux2_bcast_s4608_c3072 | bcast | 69.64 | 32.81 | 31.45 | 1.043x |
| flux2_bcast_s4096_c3072 | bcast | 61.24 | 29.83 | 27.50 | 1.085x |
| flux2_bcast_s512_c3072 | bcast | 9.21 | 22.33 | 10.26 | 2.176x |
| ltx2_full_s126_c2048 | full | 8.41 | 20.50 | 8.43 | 2.431x |
벤치마크 결과에서 볼 수 있듯이, 특히 broadcast gate 패턴에서 Triton 대비 최대 2.2배 (B200) / 2.4배 (H100) 이상의 속도 향상을 달성했습니다. ltx2_bcast_s32640_c4096와 같은 LTX-2.3의 주요 워크로드에서도 10% 이상의 성능 향상을 보였습니다. 이는 Diffusion 모델의 추론 시간을 단축하는 데 직접적으로 기여합니다.
2. Nsight Compute 분석을 통한 최적화 효과 입증
대규모 broadcast workload (ltx2_bcast_s32640_c4096)에 대한 Nsight Compute 분석 결과는 row-tiled CUDA path의 효율성을 명확히 보여줍니다.
| Implementation | NCU duration us | DRAM % | SM % | Global load inst | L1/TEX load sectors | Regs/thread | Achieved occupancy |
|---|---|---|---|---|---|---|---|
| Original flat CUDA candidate | 204.99 | 56.79 | 70.54 | 1,566,720 | 25,067,520 | 44 | 51.08% |
| Triton BLC fallback | 157.60 | 73.34 | 16.07 | 1,566,720 | 20,889,600 | 112 | 22.79% |
| CUDA row-tile 4x | 132.54 | 87.06 | 42.66 | 1,175,040 | 18,800,640 | 30 | 89.93% |
NCU duration us감소: 204.99us에서 132.54us로 크게 감소하여 실제 실행 시간이 단축되었음을 보여줍니다.DRAM %증가 (87.06%): DRAM 대역폭 활용률이 높아져 GPU의 메모리 시스템을 더욱 효율적으로 사용하고 있음을 나타냅니다.Global load inst및L1/TEX load sectors감소: 전역 메모리 로드 명령어와 L1/TEX 캐시 로드 섹터가 감소했습니다. 이는row-tiled방식이gate벡터를 재사용하여 불필요한 메모리 접근을 줄였기 때문입니다.Regs/thread감소 (30): 스레드당 레지스터 사용량이 줄어들어 더 많은 스레드가 동시에 실행될 수 있는 여지를 확보했습니다.Achieved occupancy대폭 증가 (89.93%): GPU의 Streaming Multiprocessor (SM) 점유율이 크게 높아져 하드웨어 리소스 활용률이 극대화되었습니다.
이러한 지표들은 "row-tiled CUDA path가 각 broadcast gate 벡터를 4개 행에 걸쳐 재사용하고, 핫 루프에서 비용이 많이 드는 flat-index modulo 연산을 제거하며, 반복적인 gate 메모리 트래픽을 줄인다"는 PR 설명의 핵심 최적화 전략이 성공적으로 작동했음을 증명합니다.
3. End-to-End 모델 성능 향상
B200 Model E2E Benchmark:
| Implementation | E2E ms | LTX2AVDenoisingStage ms | LTX2RefinementStage ms | Peak reserved MB | Speedup |
|---|---|---|---|---|---|
| Main torch expression | 46644.08 | 29839.68 | 12758.12 | 67902 | 1.00x |
| CUDA JIT fast path | 45198.37 | 28816.24 | 12419.63 | 67642 | 1.032x (+3.20%) |
전체 LTX-2.3 HQ 모델 벤치마크에서도 3.2%의 유의미한 속도 향상을 달성했습니다. 이는 특정 "핫 패턴"에 대한 커널 최적화가 전체 모델 성능에 긍정적인 영향을 미칠 수 있음을 보여줍니다.
4. 일반적 교훈
- 워크로드별 맞춤형 커널의 중요성: 범용적인 연산은 편리하지만, 특정 하드웨어 및 워크로드의 "핫 패턴"에서는 맞춤형 CUDA 커널이 압도적인 성능 우위를 제공할 수 있습니다.
- 메모리 접근 패턴 최적화:
row-tiled방식과 같은 메모리 접근 패턴 최적화는 GPU의 메모리 대역폭 활용률을 높이고 캐시 미스를 줄여 성능을 크게 향상시킵니다. - JIT 컴파일러 호환성:
torch.compile과 같은 JIT 컴파일러와의 호환성을 위해 커스텀 Op를 등록하고 fake impl을 제공하는 전략은 프레임워크의 유연성을 유지하면서도 최적화된 커널을 활용할 수 있게 합니다. - 견고한 입력 검증: 커널의 안정성을 위해 입력 텐서의 속성을 엄격하게 검증하는 것은 필수적입니다.
결론
이 PR은 SGLang의 LTX-2.3 Diffusion 모델에서 residual-gate update 연산의 성능 병목을 성공적으로 해결했습니다. 네이티브 CUDA JIT 커널과 row-tiled 메모리 접근 전략을 통해 대규모 broadcast gate 패턴에서 획기적인 속도 향상을 달성했으며, 이는 전체 모델의 추론 시간 단축으로 이어졌습니다. 이러한 최적화는 딥러닝 모델의 실제 배포 및 활용에 있어 매우 중요한 기여를 하며, SGLang이 고성능 멀티모달 모델을 지원하는 데 있어 핵심적인 역할을 수행하고 있음을 보여줍니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] vLLM의 GLM5.2 성능 최적화: Triton 커널 융합을 통한 E2E Throughput 향상
- 현재글 : [sglang] SGLang LTX-2.3 Diffusion 모델 최적화: Residual-Gate 연산 CUDA Fast Path 도입
- 다음글 [vllm] vLLM ROCm 환경에서 FlyDSL을 활용한 MXFP8 MoE 성능 최적화
댓글