[SGLang] LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교
들어가며
SGLang은 LoRA 연산을 수행하는 3가지 백엔드를 제공한다. PyTorch 네이티브(torch_native), Triton 기반(triton), 그리고 Chunked SGMV(csgmv)이다. 각 백엔드는 BaseLoRABackend를 상속하며, 동일한 인터페이스로 SGEMM(Segmented GEMM) 연산을 제공한다.
구조도
┌─────────────────────────────────────────────┐
│ BaseLoRABackend │
│ - run_lora_a_sgemm(x, weights) -> tensor │
│ - run_lora_b_sgemm(x, weights) -> tensor │
│ - run_qkv_lora(x, a, b, ...) -> tensor │
│ - run_gate_up_lora(x, a, b) -> tensor │
│ - prepare_lora_batch(forward_batch, ...) │
└──────────────────┬──────────────────────────┘
│ 상속
┌─────────────┼─────────────┐
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────────┐
│ Torch │ │ Triton │ │ Chunked │
│ Native │ │ Backend │ │ SGMV │
│ │ │ │ │ (csgmv) │
├──────────┤ ├──────────┤ ├──────────────┤
│ CPU-side │ │ GPU-side │ │ 청크 분할 │
│ loop per │ │ fused │ │ + SGMV │
│ segment │ │ kernels │ │ kernels │
└──────────┘ └──────────┘ └──────────────┘
핵심 코드 분석
BaseLoRABackend: 공통 인터페이스
python/sglang/srt/lora/backend/base_backend.py에서 모든 백엔드가 구현해야 하는 인터페이스를 정의한다.
class BaseLoRABackend(LoRABackendLmHeadMixing):
def __init__(self, max_loras_per_batch, device):
self.max_loras_per_batch = max_loras_per_batch
self.device = device
def run_lora_a_sgemm(self, x, weights, *args, **kwargs) -> torch.Tensor:
"""입력 x (s, input_dim) * weights (num_lora, c*r, input_dim) -> (s, c*r)"""
pass
def run_lora_b_sgemm(self, x, weights, *args, **kwargs) -> torch.Tensor:
"""중간결과 x (s, r) * weights (num_lora, output_dim, r) -> (s, output_dim)"""
pass
TorchNativeLoRABackend: PyTorch 구현
CPU 측에서 세그먼트별로 루프를 돌며 각 LoRA의 행렬곱을 수행한다. 참조 구현으로 사용된다.
class TorchNativeLoRABackend(BaseLoRABackend):
name = "torch_native"
def run_lora_a_sgemm(self, x, weights, stack_num=1, *args, **kwargs):
output_tensor = sgemm_lora_a_fwd(
inputs=x,
weights=weights,
weight_indices=self.batch_info.weight_indices_cpu,
seg_len_tensor=self.batch_info.seg_lens_cpu,
lora_ranks=self.batch_info.lora_ranks_cpu,
scaling_tensor=self.batch_info.scalings_cpu,
num_slices=stack_num,
)
return output_tensor
Torch 백엔드는 추가로 CPU 측 배치 정보를 유지한다.
@dataclass
class TorchNativeLoRABatchInfo(LoRABatchInfo):
lora_ranks_cpu: Optional[torch.Tensor] = None
seg_indptr_cpu: Optional[torch.Tensor] = None
seg_lens_cpu: Optional[torch.Tensor] = None
weight_indices_cpu: Optional[torch.Tensor] = None
scalings_cpu: Optional[torch.Tensor] = None
TritonLoRABackend: Triton 최적화
GPU 측 Triton 커널을 사용하여 세그먼트 단위 병렬 GEMM을 수행한다.
class TritonLoRABackend(BaseLoRABackend):
name = "triton"
def run_lora_a_sgemm(self, x, weights, pruned_batch_info=None, stack_num=1, ...):
batch_info = pruned_batch_info if pruned_batch_info is not None else self.batch_info
return sgemm_lora_a_fwd(x, weights, batch_info, stack_num=stack_num)
def run_lora_b_sgemm(self, x, weights, base_output=None, pruned_batch_info=None, ...):
batch_info = pruned_batch_info if pruned_batch_info is not None else self.batch_info
return sgemm_lora_b_fwd(x, weights, batch_info, base_output)
QKV와 Gate/Up에 대해 특화된 커널을 제공한다.
def run_qkv_lora(self, x, qkv_lora_a, qkv_lora_b, output_offset,
max_qkv_out_dim, base_output=None, ...):
# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3*r, input_dim) - 3개 행렬 스택
# qkv_lora_b: (num_lora, output_dim_q + 2*output_dim_kv, r)
lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
lora_output = qkv_lora_b_fwd(
lora_a_output, qkv_lora_b, self.batch_info,
output_offset, max_qkv_out_dim, base_output)
return lora_output
def run_gate_up_lora(self, x, gate_up_lora_a, gate_up_lora_b,
base_output=None, ...):
lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info, stack_num=2)
lora_output = gate_up_lora_b_fwd(
lora_a_output, gate_up_lora_b, self.batch_info, base_output)
return lora_output
ChunkedSgmvLoRABackend: 청크 기반 SGMV
Punica 논문의 SGMV 알고리즘을 기반으로, 입력 시퀀스를 고정 크기 청크로 분할하여 처리한다. LoRA 분포가 편향된 경우(한 LoRA에 토큰이 집중) 과도한 커널 런치를 방지한다.
class ChunkedSgmvLoRABackend(BaseLoRABackend):
name = "csgmv"
def __init__(self, max_loras_per_batch, device, server_args):
super().__init__(max_loras_per_batch, device)
self.max_chunk_size = server_args.max_lora_chunk_size
def run_lora_a_sgemm(self, x, weights, pruned_batch_info=None, stack_num=1, ...):
batch_info = pruned_batch_info if pruned_batch_info is not None else self.batch_info
return chunked_sgmv_lora_shrink_forward(
x=x, weights=weights, batch_info=batch_info, num_slices=stack_num)
def run_lora_b_sgemm(self, x, weights, output_offset, base_output=None, ...):
output_dim = weights.shape[-2]
return chunked_sgmv_lora_expand_forward(
x=x, weights=weights, batch_info=batch_info,
slice_offsets=output_offset, max_slice_size=output_dim,
base_output=base_output)
Shrink(A 행렬)와 Expand(B 행렬)로 명명하는 것이 SGMV의 특징이다. A 행렬은 고차원→저차원(shrink), B 행렬은 저차원→고차원(expand)이다.
백엔드 비교
┌──────────┬────────────────┬────────────────┬─────────────────┐
│ │ torch_native │ triton │ csgmv │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ 실행위치 │ CPU loop + │ GPU Triton │ GPU Triton │
│ │ GPU matmul │ fused kernel │ chunked kernel │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ 커널호출 │ O(num_segments)│ O(1) per layer │ O(num_chunks) │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ 장점 │ 디버깅 용이 │ 최적 성능 │ 편향 분포 대응 │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ CUDA │ 미지원 │ 지원 │ 지원 │
│ Graph │ │ │ │
├──────────┼────────────────┼────────────────┼─────────────────┤
│ QKV │ 별도 커널 │ qkv_lora_b_fwd │ chunked expand │
│ 특화 │ │ 전용 커널 │ │
└──────────┴────────────────┴────────────────┴─────────────────┘
언제 어떤 백엔드를 선택할까?
- triton: 대부분의 프로덕션 환경에서 기본 선택. 단일 커널 호출로 모든 세그먼트를 처리한다.
- csgmv: LoRA 분포가 극도로 편향된 경우(한 어댑터가 배치 대부분 차지). 청크 분할로 부하를 균등화한다.
- torch_native: 개발/디버깅 용도. CPU 측 루프로 동작하여 결과 검증에 유용하다.
설계 근거
pruned_batch_info
LM Head LoRA에서는 Logprobs 계산 시 일부 토큰만 처리하므로, 전체 배치 정보 대신 pruned(축소된) 배치 정보를 사용한다. Triton과 Chunked 백엔드 모두 이를 지원한다.
Embedding LoRA 특수 처리
임베딩 레이어는 SGEMM이 아닌 임베딩 룩업으로 A 행렬을 적용한다. Triton 백엔드는 전용 embedding_lora_a_fwd 커널을, Chunked 백엔드는 chunked_embedding_lora_a_forward를 사용한다.
관련 포스트
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] LoRA Layers: QKV, Gate/Up 프로젝션 어댑터
- 현재글 : [SGLang] LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교
- 다음글 [SGLang] LoRA Triton 커널: SGMV, SGEMM 최적화 연산
댓글