본문으로 건너뛰기

[SGLang] LoRA Triton 커널: SGMV, SGEMM 최적화 연산

들어가며

멀티 LoRA 서빙의 핵심 성능은 Segmented GEMM 커널에 달려있다. 여러 요청이 서로 다른 LoRA 어댑터를 사용할 때, 세그먼트별로 다른 가중치 행렬을 적용하면서도 GPU 병렬성을 최대한 활용해야 한다. SGLang은 SGEMM(Segmented GEMM)과 Chunked SGMV(Segmented Gather Matrix-Vector) 두 가지 Triton 커널 제품군을 제공한다.

구조도

배치 입력 (s 토큰, 여러 LoRA)
┌─────────────────────────────────────────┐
│  seg0 (LoRA A) │ seg1 (LoRA B) │ seg2   │
│  len=32        │ len=64        │ len=16 │
└────────────────┴───────────────┴────────┘
         │                │            │
    weights[A]       weights[B]   weights[A]
         │                │            │
         ▼                ▼            ▼
┌─────────────────────────────────────────┐
│            SGEMM Kernel                 │
│  grid = (tiles_per_seg, num_segments)   │
│  각 program이 (batch_id, tile_id) 처리   │
└─────────────────────────────────────────┘

핵심 코드 분석

SGEMM LoRA A: Shrink 커널

python/sglang/srt/lora/triton_ops/sgemm_lora_a.py의 커널은 입력을 저랭크 공간으로 변환한다 (input_dim → rank).

@triton.jit
def _sgemm_lora_a_kernel(
    x, weights, output,
    N, K, stack_num,        # N = stack_num * rank, K = input_dim
    x_stride_0, x_stride_1,
    w_stride_0, w_stride_1, w_stride_2,
    output_stride_0, output_stride_1,
    seg_lens, seg_indptr, weight_indices, lora_ranks,
    BLOCK_S: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    batch_id = tl.program_id(axis=1)
    w_index = tl.load(weight_indices + batch_id)
    rank = tl.load(lora_ranks + w_index)

    if rank == 0:
        return  # no-op: base model 요청

    pid = tl.program_id(axis=0)
    seg_start = tl.load(seg_indptr + batch_id)
    seg_len = tl.load(seg_lens + batch_id)

    N = tl.minimum(N, rank * stack_num)  # 실제 rank에 맞게 조정

핵심 설계:

  1. 2D 그리드: (tiles_per_segment, num_segments). 각 program은 하나의 세그먼트(LoRA)의 한 타일을 처리한다.
  2. rank=0 early exit: base model 요청은 커널 진입 즉시 반환한다.
  3. 동적 N 조정: tl.minimum(N, rank * stack_num)으로 각 어댑터의 실제 rank에 맞게 출력 차원을 조정한다.

행렬곱 루프:

    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(x_ptrs,
            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
            other=0.0)
        w_tile = tl.load(w_ptrs,
            mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
            other=0.0)
        partial_sum += tl.dot(x_tile, w_tile)

K 차원(input_dim)을 BLOCK_K 단위로 순회하며 누적한다.

SGEMM LoRA B: Expand 커널

python/sglang/srt/lora/triton_ops/sgemm_lora_b.py는 저랭크 표현을 출력 차원으로 확장하고, base output에 in-place로 더한다.

@triton.jit
def _sgemm_lora_b_kernel(
    x, weights, output,
    N, K,                   # N = output_dim, K = rank
    ..., seg_lens, seg_indptr, weight_indices, lora_ranks,
    BLOCK_S, BLOCK_N, BLOCK_K,
    scalings,               # LoRA scaling factor
):
    batch_id = tl.program_id(axis=1)
    w_index = tl.load(weight_indices + batch_id)
    rank = tl.load(lora_ranks + w_index)
    if rank == 0:
        return

    scaling = tl.load(scalings + w_index)
    K = tl.minimum(K, rank)

    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(x_ptrs, ...)
        w_tile = tl.load(w_ptrs, ...)
        partial_sum += tl.dot(x_tile, w_tile)

Fused scaling: LoRA의 scaling factor(alpha/rank)를 커널 내부에서 적용하여 별도 연산을 제거한다. 결과를 base output에 누적하여 output = base + scaling * (x @ B) 를 단일 커널로 처리한다.

Chunked SGMV: Shrink 커널

python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py는 세그먼트를 고정 크기 청크로 분할하여 처리한다.

@triton.jit
def _chunked_lora_shrink_kernel(
    x, weights, output,
    seg_indptr, weight_indices, lora_ranks, permutation, num_segs,
    N: tl.constexpr, K: tl.constexpr, NUM_SLICES: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_s = tl.program_id(1)
    if pid_s >= num_segs:
        return
    pid_n = tl.program_id(0)

    w_index = tl.load(weight_indices + pid_s)
    rank = tl.load(lora_ranks + w_index)
    if rank == 0:
        return

    seg_start = tl.load(seg_indptr + pid_s)
    seg_end = tl.load(seg_indptr + pid_s + 1)
    cur_n = tl.minimum(N, rank * NUM_SLICES)

    # 논리적 인덱스 → 물리적 인덱스 매핑
    s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
    s_offset_physical = tl.load(
        permutation + s_offset_logical, mask=s_offset_logical < seg_end)

일반 SGEMM과의 핵심 차이:

  1. Permutation: permutation 배열로 논리적 세그먼트 인덱스를 물리적 토큰 인덱스로 매핑한다. 이를 통해 청크 경계에서 불연속적인 토큰을 처리할 수 있다.
  2. 고정 BLOCK_M: 세그먼트 길이와 무관하게 고정 크기 청크로 처리하여, 극단적으로 긴 세그먼트에서의 워크로드 불균형을 방지한다.
  3. constexpr 최적화: N, K, NUM_SLICES를 컴파일 타임 상수로 지정하여 커널 특수화를 유도한다.

커널 캐싱

Chunked SGMV 커널은 @cached_triton_kernel로 래핑되어 동일 파라미터 조합에 대해 커널을 재사용한다.

@cached_triton_kernel(
    lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
)
@triton.jit(do_not_specialize=["num_segs"])
def _chunked_lora_shrink_kernel(...):

커널 비교

┌──────────────┬─────────────────────┬─────────────────────┐
│              │      SGEMM          │   Chunked SGMV      │
├──────────────┼─────────────────────┼─────────────────────┤
│ 그리드 구성   │ (tiles, segments)   │ (N_tiles, chunks)   │
│ 세그먼트 처리 │ 가변 길이           │ 고정 BLOCK_M 청크   │
│ 인덱싱       │ 직접 seg_indptr     │ permutation 간접    │
│ 장점         │ 낮은 오버헤드        │ 편향 분포 대응      │
│ 적합 상황    │ 균등 분포            │ 편향된 LoRA 분포    │
│ K 특수화     │ 런타임 결정          │ constexpr 최적화    │
└──────────────┴─────────────────────┴─────────────────────┘

설계 근거

rank=0의 효율적 처리

모든 커널에서 rank == 0이면 즉시 반환한다. 이를 통해 base model 전용 요청이 LoRA 오버헤드 없이 처리된다. CUDA Graph 캡처 시에도 커널은 기록되지만, rank=0으로 인해 실질적 연산이 없다.

stack_num 파라미터

QKV Projection은 Q, K, V 3개의 행렬이 스택되어 있으므로 stack_num=3, Gate/Up은 stack_num=2로 설정한다. 커널 내부에서 N = rank * stack_num으로 출력 차원을 자동 조정한다.

Scaling의 Fused 처리

LoRA의 scaling factor는 원래 alpha / rank로 계산된다. 이를 B 커널의 마지막 단계에서 적용하여 별도의 element-wise 연산 커널을 제거한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글