본문으로 건너뛰기

[SGLang] LoRA Manager: 어댑터 라이프사이클 관리

들어가며

멀티 테넌트 LLM 서빙에서는 하나의 베이스 모델 위에 다수의 LoRA 어댑터를 동시에 서빙해야 한다. SGLang의 LoRAManager는 어댑터의 로딩/언로딩, 메모리 풀 관리, 배치별 LoRA 정보 준비, 그리고 CUDA Graph 호환까지 종합적으로 관리한다.

구조도

┌────────────────────────────────────────────────┐
│                  LoRAManager                    │
│                                                │
│  ┌──────────┐  ┌──────────┐  ┌──────────────┐  │
│  │ configs  │  │  loras   │  │  lora_refs   │  │
│  │ (LoRA    │  │ (LoRA    │  │ (LoRARef     │  │
│  │  Config) │  │ Adapter) │  │  registry)   │  │
│  └──────────┘  └──────────┘  └──────────────┘  │
│                                                │
│  ┌──────────────────────────────────────────┐  │
│  │         LoRAMemoryPool                    │  │
│  │  ┌────────┐ ┌────────┐ ┌────────┐        │  │
│  │  │ Slot 0 │ │ Slot 1 │ │ Slot N │  ...   │  │
│  │  │(LoRA A)│ │(LoRA B)│ │(base)  │        │  │
│  │  └────────┘ └────────┘ └────────┘        │  │
│  └──────────────────────────────────────────┘  │
│                                                │
│  ┌──────────────────────────────────────────┐  │
│  │         LoRABackend (Triton/Torch/CSGMV) │  │
│  │  - SGEMM kernels                         │  │
│  │  - Batch info management                 │  │
│  └──────────────────────────────────────────┘  │
└────────────────────────────────────────────────┘

핵심 코드 분석

초기화: 백엔드와 상태 설정

python/sglang/srt/lora/lora_manager.py에서 LoRAManager를 초기화한다.

class LoRAManager:
    def __init__(self, base_model, base_hf_config, max_loras_per_batch,
                 load_config, dtype, server_args, lora_backend="triton",
                 tp_size=1, tp_rank=0, ...):
        self.base_model = base_model
        self.max_loras_per_batch = max_loras_per_batch
        self.eviction_policy = server_args.lora_eviction_policy

        backend_type = get_backend_from_name(lora_backend)
        self.lora_backend: BaseLoRABackend = backend_type(
            max_loras_per_batch=max_loras_per_batch,
            device=self.device,
            server_args=server_args,
        )
        self.init_state(max_lora_rank=max_lora_rank,
                       target_modules=target_modules,
                       lora_paths=lora_paths)

max_loras_per_batch는 한 번의 forward에서 동시에 적용할 수 있는 LoRA 수의 상한이다.

어댑터 로딩

런타임에 새 LoRA 어댑터를 로드한다. 설정 검증 후 가중치를 메모리 풀에 적재한다.

def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
    try:
        new_adapter = LoRAConfig(
            lora_ref.lora_path,
            base_vocab_size=self.base_hf_config.vocab_size,
        )
        self.validate_new_adapter(new_adapter, lora_ref)
        self.configs[lora_ref.lora_id] = new_adapter
        self.load_lora_weights(lora_ref)
        self.lora_refs[lora_ref.lora_id] = lora_ref
        self.num_pinned_loras += int(lora_ref.pinned)
    except Exception as e:
        return self.create_lora_update_result(success=False, error_message=str(e))
    return self.create_lora_update_result(success=True)

어댑터 검증

새 어댑터의 호환성을 메모리 풀 구성과 대조 검증한다.

def validate_new_adapter(self, lora_config, lora_ref):
    if lora_config.lora_added_tokens_size > 0:
        raise ValueError("LoRA serving currently doesn't support adapters that add tokens")
    memory_pool = getattr(self, "memory_pool", None)
    incompatible = memory_pool and not memory_pool.can_support(lora_config)
    if incompatible:
        raise ValueError(
            f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} "
            "is incompatible with the current LoRA memory pool configuration.")
    if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
        raise ValueError("Not allowed to pin all slots to avoid starvation")

Pinned LoRA(항상 메모리에 유지)가 모든 슬롯을 차지하면 기아 상태가 발생하므로, 최소 1개 슬롯은 비핀 어댑터용으로 남겨둔다.

배치 준비: prepare_lora_batch

Forward 배치의 각 요청에 대해 LoRA 인덱스, rank, scaling 정보를 설정한다.

def prepare_lora_batch(self, forward_batch: ForwardBatch):
    weight_indices = [0] * len(forward_batch.lora_ids)
    lora_ranks = [0] * self.max_loras_per_batch
    scalings = [0] * self.max_loras_per_batch

    for i, uid in enumerate(forward_batch.lora_ids):
        weight_indices[i] = self.memory_pool.get_buffer_id(uid)
        if uid is not None:
            lora = self.loras[uid]
            lora_ranks[weight_indices[i]] = lora.config.r
            scalings[weight_indices[i]] = lora.scaling

    self.lora_backend.prepare_lora_batch(
        forward_batch=forward_batch,
        weight_indices=weight_indices,
        lora_ranks=lora_ranks,
        scalings=scalings,
        use_cuda_graph=use_cuda_graph,
    )

weight_indices는 각 요청이 메모리 풀의 어떤 슬롯의 LoRA 가중치를 사용할지 매핑한다. LoRA가 없는 요청(uid=None)은 rank=0이 되어 커널이 no-op으로 처리한다.

LoRA 모듈 정보 업데이트

메모리 풀의 가중치가 변경되면 모든 LoRA 레이어에 최신 버퍼를 반영한다.

def update_lora_info(self):
    for layer_id, layer_modules in enumerate(self.lora_modules):
        for module_name, module in layer_modules.items():
            if isinstance(module, FusedMoEWithLoRA):
                # MoE 레이어는 gate_up/down 별도 설정
                module.set_lora_info(
                    gate_up_lora_a_weights=gate_up_a,
                    gate_up_lora_b_weights=gate_up_b,
                    down_lora_a_weights=down_a,
                    down_lora_b_weights=down_b,
                )
                continue
            target_module = get_target_module_name(module_name, ...)
            module.set_lora_info(
                self.memory_pool.get_tensor(target_module, layer_id, LoRAType.LORA_A),
                self.memory_pool.get_tensor(target_module, layer_id, LoRAType.LORA_B),
            )

배치 검증

배치 내 LoRA ID 수가 max_loras_per_batch를 초과하지 않는지 확인한다.

def validate_lora_batch(self, lora_ids: set[Optional[str]]) -> bool:
    if len(lora_ids) > self.max_loras_per_batch:
        return False
    pinned_loras_in_batch = sum(
        int(self.lora_refs[lid].pinned) for lid in lora_ids if lid is not None
    )
    required_slots = len(lora_ids) - pinned_loras_in_batch
    mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
    return required_slots <= mem_pool_vacancy

CUDA Graph 지원

LoRA 배치 정보의 in-place 업데이트로 CUDA Graph와 호환되도록 한다.

def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph, num_tokens_per_bs):
    self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
    self.lora_backend.init_cuda_graph_batch_info(
        max_bs_in_cuda_graph=max_bs_in_cuda_graph,
        num_tokens_per_bs=num_tokens_per_bs,
    )

설계 근거

왜 메모리 풀 기반 관리인가?

매 forward마다 LoRA 가중치를 로드하면 오버헤드가 크다. 고정 크기 메모리 풀에 최대 max_loras_per_batch개의 어댑터를 미리 배치하고, 슬롯 단위로 교체(eviction)하여 가중치 로딩 비용을 분산한다.

Pinned vs Unpinned LoRA

자주 사용되는 어댑터를 pinned로 설정하면 eviction 대상에서 제외되어 항상 GPU 메모리에 상주한다. 단, 모든 슬롯이 pinned되면 새 어댑터를 수용할 수 없으므로 최소 1개 슬롯을 비워둔다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글