[SGLang] LoRA Triton 커널: SGMV, SGEMM 최적화 연산
들어가며
멀티 LoRA 서빙의 핵심 성능은 Segmented GEMM 커널에 달려있다. 여러 요청이 서로 다른 LoRA 어댑터를 사용할 때, 세그먼트별로 다른 가중치 행렬을 적용하면서도 GPU 병렬성을 최대한 활용해야 한다. SGLang은 SGEMM(Segmented GEMM)과 Chunked SGMV(Segmented Gather Matrix-Vector) 두 가지 Triton 커널 제품군을 제공한다.
구조도
배치 입력 (s 토큰, 여러 LoRA)
┌─────────────────────────────────────────┐
│ seg0 (LoRA A) │ seg1 (LoRA B) │ seg2 │
│ len=32 │ len=64 │ len=16 │
└────────────────┴───────────────┴────────┘
│ │ │
weights[A] weights[B] weights[A]
│ │ │
▼ ▼ ▼
┌─────────────────────────────────────────┐
│ SGEMM Kernel │
│ grid = (tiles_per_seg, num_segments) │
│ 각 program이 (batch_id, tile_id) 처리 │
└─────────────────────────────────────────┘
핵심 코드 분석
SGEMM LoRA A: Shrink 커널
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py의 커널은 입력을 저랭크 공간으로 변환한다 (input_dim → rank).
@triton.jit
def _sgemm_lora_a_kernel(
x, weights, output,
N, K, stack_num, # N = stack_num * rank, K = input_dim
x_stride_0, x_stride_1,
w_stride_0, w_stride_1, w_stride_2,
output_stride_0, output_stride_1,
seg_lens, seg_indptr, weight_indices, lora_ranks,
BLOCK_S: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
batch_id = tl.program_id(axis=1)
w_index = tl.load(weight_indices + batch_id)
rank = tl.load(lora_ranks + w_index)
if rank == 0:
return # no-op: base model 요청
pid = tl.program_id(axis=0)
seg_start = tl.load(seg_indptr + batch_id)
seg_len = tl.load(seg_lens + batch_id)
N = tl.minimum(N, rank * stack_num) # 실제 rank에 맞게 조정
핵심 설계:
- 2D 그리드:
(tiles_per_segment, num_segments). 각 program은 하나의 세그먼트(LoRA)의 한 타일을 처리한다. - rank=0 early exit: base model 요청은 커널 진입 즉시 반환한다.
- 동적 N 조정:
tl.minimum(N, rank * stack_num)으로 각 어댑터의 실제 rank에 맞게 출력 차원을 조정한다.
행렬곱 루프:
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(x_ptrs,
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
other=0.0)
w_tile = tl.load(w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
other=0.0)
partial_sum += tl.dot(x_tile, w_tile)
K 차원(input_dim)을 BLOCK_K 단위로 순회하며 누적한다.
SGEMM LoRA B: Expand 커널
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py는 저랭크 표현을 출력 차원으로 확장하고, base output에 in-place로 더한다.
@triton.jit
def _sgemm_lora_b_kernel(
x, weights, output,
N, K, # N = output_dim, K = rank
..., seg_lens, seg_indptr, weight_indices, lora_ranks,
BLOCK_S, BLOCK_N, BLOCK_K,
scalings, # LoRA scaling factor
):
batch_id = tl.program_id(axis=1)
w_index = tl.load(weight_indices + batch_id)
rank = tl.load(lora_ranks + w_index)
if rank == 0:
return
scaling = tl.load(scalings + w_index)
K = tl.minimum(K, rank)
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(x_ptrs, ...)
w_tile = tl.load(w_ptrs, ...)
partial_sum += tl.dot(x_tile, w_tile)
Fused scaling: LoRA의 scaling factor(alpha/rank)를 커널 내부에서 적용하여 별도 연산을 제거한다. 결과를 base output에 누적하여 output = base + scaling * (x @ B) 를 단일 커널로 처리한다.
Chunked SGMV: Shrink 커널
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py는 세그먼트를 고정 크기 청크로 분할하여 처리한다.
@triton.jit
def _chunked_lora_shrink_kernel(
x, weights, output,
seg_indptr, weight_indices, lora_ranks, permutation, num_segs,
N: tl.constexpr, K: tl.constexpr, NUM_SLICES: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_s = tl.program_id(1)
if pid_s >= num_segs:
return
pid_n = tl.program_id(0)
w_index = tl.load(weight_indices + pid_s)
rank = tl.load(lora_ranks + w_index)
if rank == 0:
return
seg_start = tl.load(seg_indptr + pid_s)
seg_end = tl.load(seg_indptr + pid_s + 1)
cur_n = tl.minimum(N, rank * NUM_SLICES)
# 논리적 인덱스 → 물리적 인덱스 매핑
s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
s_offset_physical = tl.load(
permutation + s_offset_logical, mask=s_offset_logical < seg_end)
일반 SGEMM과의 핵심 차이:
- Permutation:
permutation배열로 논리적 세그먼트 인덱스를 물리적 토큰 인덱스로 매핑한다. 이를 통해 청크 경계에서 불연속적인 토큰을 처리할 수 있다. - 고정 BLOCK_M: 세그먼트 길이와 무관하게 고정 크기 청크로 처리하여, 극단적으로 긴 세그먼트에서의 워크로드 불균형을 방지한다.
- constexpr 최적화:
N,K,NUM_SLICES를 컴파일 타임 상수로 지정하여 커널 특수화를 유도한다.
커널 캐싱
Chunked SGMV 커널은 @cached_triton_kernel로 래핑되어 동일 파라미터 조합에 대해 커널을 재사용한다.
@cached_triton_kernel(
lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
)
@triton.jit(do_not_specialize=["num_segs"])
def _chunked_lora_shrink_kernel(...):
커널 비교
┌──────────────┬─────────────────────┬─────────────────────┐
│ │ SGEMM │ Chunked SGMV │
├──────────────┼─────────────────────┼─────────────────────┤
│ 그리드 구성 │ (tiles, segments) │ (N_tiles, chunks) │
│ 세그먼트 처리 │ 가변 길이 │ 고정 BLOCK_M 청크 │
│ 인덱싱 │ 직접 seg_indptr │ permutation 간접 │
│ 장점 │ 낮은 오버헤드 │ 편향 분포 대응 │
│ 적합 상황 │ 균등 분포 │ 편향된 LoRA 분포 │
│ K 특수화 │ 런타임 결정 │ constexpr 최적화 │
└──────────────┴─────────────────────┴─────────────────────┘
설계 근거
rank=0의 효율적 처리
모든 커널에서 rank == 0이면 즉시 반환한다. 이를 통해 base model 전용 요청이 LoRA 오버헤드 없이 처리된다. CUDA Graph 캡처 시에도 커널은 기록되지만, rank=0으로 인해 실질적 연산이 없다.
stack_num 파라미터
QKV Projection은 Q, K, V 3개의 행렬이 스택되어 있으므로 stack_num=3, Gate/Up은 stack_num=2로 설정한다. 커널 내부에서 N = rank * stack_num으로 출력 차원을 자동 조정한다.
Scaling의 Fused 처리
LoRA의 scaling factor는 원래 alpha / rank로 계산된다. 이를 B 커널의 마지막 단계에서 적용하여 별도의 element-wise 연산 커널을 제거한다.
관련 포스트
- LoRA Layers: QKV, Gate/Up 프로젝션 어댑터
- LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교
- LoRA + MoE 융합: 어댑터와 전문가 혼합의 통합
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [vllm] AMD RDNA3 (gfx1100)를 위한 vLLM의 W4A16 GPTQ 커널 최적화 심층 분석
- [sglang] SGLang의 MoE 성능 최적화: 512 전문가 모델을 위한 커널 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
SGLang 의 다른글
- 이전글 [SGLang] LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교
- 현재글 : [SGLang] LoRA Triton 커널: SGMV, SGEMM 최적화 연산
- 다음글 [SGLang] LoRA + MoE 융합: 어댑터와 전문가 혼합의 통합
댓글