본문으로 건너뛰기

[SGLang] LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교

들어가며

SGLang은 LoRA 연산을 수행하는 3가지 백엔드를 제공한다. PyTorch 네이티브(torch_native), Triton 기반(triton), 그리고 Chunked SGMV(csgmv)이다. 각 백엔드는 BaseLoRABackend를 상속하며, 동일한 인터페이스로 SGEMM(Segmented GEMM) 연산을 제공한다.

구조도

┌─────────────────────────────────────────────┐
│              BaseLoRABackend                │
│  - run_lora_a_sgemm(x, weights) -> tensor   │
│  - run_lora_b_sgemm(x, weights) -> tensor   │
│  - run_qkv_lora(x, a, b, ...) -> tensor     │
│  - run_gate_up_lora(x, a, b) -> tensor      │
│  - prepare_lora_batch(forward_batch, ...)    │
└──────────────────┬──────────────────────────┘
                   │ 상속
     ┌─────────────┼─────────────┐
     ▼             ▼             ▼
┌──────────┐ ┌──────────┐ ┌──────────────┐
│  Torch   │ │  Triton  │ │  Chunked     │
│  Native  │ │  Backend │ │  SGMV        │
│          │ │          │ │  (csgmv)     │
├──────────┤ ├──────────┤ ├──────────────┤
│ CPU-side │ │ GPU-side │ │ 청크 분할     │
│ loop per │ │ fused    │ │ + SGMV       │
│ segment  │ │ kernels  │ │ kernels      │
└──────────┘ └──────────┘ └──────────────┘

핵심 코드 분석

BaseLoRABackend: 공통 인터페이스

python/sglang/srt/lora/backend/base_backend.py에서 모든 백엔드가 구현해야 하는 인터페이스를 정의한다.

class BaseLoRABackend(LoRABackendLmHeadMixing):
    def __init__(self, max_loras_per_batch, device):
        self.max_loras_per_batch = max_loras_per_batch
        self.device = device

    def run_lora_a_sgemm(self, x, weights, *args, **kwargs) -> torch.Tensor:
        """입력 x (s, input_dim) * weights (num_lora, c*r, input_dim) -> (s, c*r)"""
        pass

    def run_lora_b_sgemm(self, x, weights, *args, **kwargs) -> torch.Tensor:
        """중간결과 x (s, r) * weights (num_lora, output_dim, r) -> (s, output_dim)"""
        pass

TorchNativeLoRABackend: PyTorch 구현

CPU 측에서 세그먼트별로 루프를 돌며 각 LoRA의 행렬곱을 수행한다. 참조 구현으로 사용된다.

class TorchNativeLoRABackend(BaseLoRABackend):
    name = "torch_native"

    def run_lora_a_sgemm(self, x, weights, stack_num=1, *args, **kwargs):
        output_tensor = sgemm_lora_a_fwd(
            inputs=x,
            weights=weights,
            weight_indices=self.batch_info.weight_indices_cpu,
            seg_len_tensor=self.batch_info.seg_lens_cpu,
            lora_ranks=self.batch_info.lora_ranks_cpu,
            scaling_tensor=self.batch_info.scalings_cpu,
            num_slices=stack_num,
        )
        return output_tensor

Torch 백엔드는 추가로 CPU 측 배치 정보를 유지한다.

@dataclass
class TorchNativeLoRABatchInfo(LoRABatchInfo):
    lora_ranks_cpu: Optional[torch.Tensor] = None
    seg_indptr_cpu: Optional[torch.Tensor] = None
    seg_lens_cpu: Optional[torch.Tensor] = None
    weight_indices_cpu: Optional[torch.Tensor] = None
    scalings_cpu: Optional[torch.Tensor] = None

TritonLoRABackend: Triton 최적화

GPU 측 Triton 커널을 사용하여 세그먼트 단위 병렬 GEMM을 수행한다.

class TritonLoRABackend(BaseLoRABackend):
    name = "triton"

    def run_lora_a_sgemm(self, x, weights, pruned_batch_info=None, stack_num=1, ...):
        batch_info = pruned_batch_info if pruned_batch_info is not None else self.batch_info
        return sgemm_lora_a_fwd(x, weights, batch_info, stack_num=stack_num)

    def run_lora_b_sgemm(self, x, weights, base_output=None, pruned_batch_info=None, ...):
        batch_info = pruned_batch_info if pruned_batch_info is not None else self.batch_info
        return sgemm_lora_b_fwd(x, weights, batch_info, base_output)

QKV와 Gate/Up에 대해 특화된 커널을 제공한다.

    def run_qkv_lora(self, x, qkv_lora_a, qkv_lora_b, output_offset,
                     max_qkv_out_dim, base_output=None, ...):
        # x: (s, input_dim)
        # qkv_lora_a: (num_lora, 3*r, input_dim) - 3개 행렬 스택
        # qkv_lora_b: (num_lora, output_dim_q + 2*output_dim_kv, r)
        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
        lora_output = qkv_lora_b_fwd(
            lora_a_output, qkv_lora_b, self.batch_info,
            output_offset, max_qkv_out_dim, base_output)
        return lora_output

    def run_gate_up_lora(self, x, gate_up_lora_a, gate_up_lora_b,
                         base_output=None, ...):
        lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info, stack_num=2)
        lora_output = gate_up_lora_b_fwd(
            lora_a_output, gate_up_lora_b, self.batch_info, base_output)
        return lora_output

ChunkedSgmvLoRABackend: 청크 기반 SGMV

Punica 논문의 SGMV 알고리즘을 기반으로, 입력 시퀀스를 고정 크기 청크로 분할하여 처리한다. LoRA 분포가 편향된 경우(한 LoRA에 토큰이 집중) 과도한 커널 런치를 방지한다.

class ChunkedSgmvLoRABackend(BaseLoRABackend):
    name = "csgmv"

    def __init__(self, max_loras_per_batch, device, server_args):
        super().__init__(max_loras_per_batch, device)
        self.max_chunk_size = server_args.max_lora_chunk_size

    def run_lora_a_sgemm(self, x, weights, pruned_batch_info=None, stack_num=1, ...):
        batch_info = pruned_batch_info if pruned_batch_info is not None else self.batch_info
        return chunked_sgmv_lora_shrink_forward(
            x=x, weights=weights, batch_info=batch_info, num_slices=stack_num)

    def run_lora_b_sgemm(self, x, weights, output_offset, base_output=None, ...):
        output_dim = weights.shape[-2]
        return chunked_sgmv_lora_expand_forward(
            x=x, weights=weights, batch_info=batch_info,
            slice_offsets=output_offset, max_slice_size=output_dim,
            base_output=base_output)

Shrink(A 행렬)와 Expand(B 행렬)로 명명하는 것이 SGMV의 특징이다. A 행렬은 고차원→저차원(shrink), B 행렬은 저차원→고차원(expand)이다.

백엔드 비교

┌──────────┬────────────────┬────────────────┬─────────────────┐
│          │  torch_native  │    triton      │     csgmv       │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ 실행위치  │ CPU loop +     │ GPU Triton     │ GPU Triton      │
│          │ GPU matmul     │ fused kernel   │ chunked kernel  │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ 커널호출  │ O(num_segments)│ O(1) per layer │ O(num_chunks)   │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ 장점     │ 디버깅 용이     │ 최적 성능      │ 편향 분포 대응   │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ CUDA     │ 미지원         │ 지원           │ 지원            │
│ Graph    │                │                │                 │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ QKV      │ 별도 커널      │ qkv_lora_b_fwd │ chunked expand  │
│ 특화     │                │ 전용 커널      │                 │
└──────────┴────────────────┴────────────────┴─────────────────┘

언제 어떤 백엔드를 선택할까?

  1. triton: 대부분의 프로덕션 환경에서 기본 선택. 단일 커널 호출로 모든 세그먼트를 처리한다.
  2. csgmv: LoRA 분포가 극도로 편향된 경우(한 어댑터가 배치 대부분 차지). 청크 분할로 부하를 균등화한다.
  3. torch_native: 개발/디버깅 용도. CPU 측 루프로 동작하여 결과 검증에 유용하다.

설계 근거

pruned_batch_info

LM Head LoRA에서는 Logprobs 계산 시 일부 토큰만 처리하므로, 전체 배치 정보 대신 pruned(축소된) 배치 정보를 사용한다. Triton과 Chunked 백엔드 모두 이를 지원한다.

Embedding LoRA 특수 처리

임베딩 레이어는 SGEMM이 아닌 임베딩 룩업으로 A 행렬을 적용한다. Triton 백엔드는 전용 embedding_lora_a_fwd 커널을, Chunked 백엔드는 chunked_embedding_lora_a_forward를 사용한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글