본문으로 건너뛰기

[SGLang] LoRA Layers: QKV, Gate/Up 프로젝션 어댑터

들어가며

LoRA(Low-Rank Adaptation)는 사전학습된 모델의 가중치를 동결하고, 저랭크 행렬 A와 B를 추가하여 W + B @ A 형태로 파인튜닝한다. SGLang은 Transformer의 주요 레이어(Embedding, QKV, Gate/Up, Down, LM Head)에 LoRA 래퍼를 제공하며, 각 레이어의 병렬화 특성에 맞게 가중치를 슬라이싱한다.

구조도

Input x
   │
   ├──► base_layer.forward(x) ──► base_output
   │
   └──► LoRA path:
        x ──► lora_A (shrink) ──► lora_B (expand) ──► lora_delta
                                                         │
                                                         ▼
                                              base_output + lora_delta
                                                         │
                                                         ▼
                                                      output

레이어별 LoRA 적용:
┌────────────────┬─────────────────┬──────────────────┐
│ VocabEmbeddingQKV ProjectionGate/Up + Down   │
│ A: embed lookupA: (3r, in_dim) │ A: (2r, in_dim)  │
│ B: (emb, r)    │ B: (q+2kv, r)  │ B: (2*out, r)    │
└────────────────┴─────────────────┴──────────────────┘

핵심 코드 분석

BaseLayerWithLoRA: 기본 래퍼

python/sglang/srt/lora/layers.py의 모든 LoRA 레이어는 BaseLayerWithLoRA를 상속한다.

class BaseLayerWithLoRA(nn.Module):
    def __init__(self, base_layer: nn.Module, lora_backend: BaseLoRABackend):
        super().__init__()
        self.base_layer = base_layer
        self.set_lora: bool = False
        self.lora_backend = lora_backend
        if hasattr(self.base_layer, "weight"):
            self.weight = self.base_layer.weight

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)

set_lora가 False이면 기본 레이어 그대로 동작한다. LoRA 가중치가 설정되면 forward에서 delta를 추가한다.

VocabParallelEmbeddingWithLoRA: 임베딩 LoRA

임베딩 레이어의 LoRA는 특별하다. A 행렬은 임베딩 룩업으로 동작하고, B 행렬은 일반 행렬곱이다.

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def apply_lora(self, base_output, input_, batch_info):
        # A: embedding lookup (rank, vocab_size) -> (s, rank)
        lora_a_output = self.run_lora_a_embedding(input_, batch_info)
        # B: matmul (embed_dim, rank) -> (s, embed_dim)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            x=lora_a_output,
            weights=self.embedding_B_buffer,
            output_offset=self.output_offset,
            base_output=base_output,
        )
        return lora_output

TP > 1 환경에서 임베딩 LoRA는 가중치를 샤딩하지 않고 전체 복제한다. 베이스 임베딩의 all-reduce 후에 LoRA delta를 더하므로 수학적으로 올바르다.

def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
    return A  # 비샤딩

def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
    return B  # 비샤딩

ParallelLMHeadWithLoRA: LM Head LoRA

LM Head는 column-parallel이므로 B 행렬을 vocab 차원으로 슬라이싱한다.

class ParallelLMHeadWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: ParallelLMHead, lora_backend):
        super().__init__(base_layer, lora_backend)
        tp_size = base_layer.tp_size if hasattr(base_layer, "tp_size") else 1
        if tp_size > 1:
            self.shard_vocab_size = get_lm_head_lora_b_shard_size(
                self.vocab_size, shard_indices=base_layer.shard_indices)
            self.output_offset = torch.tensor(
                [0, self.shard_vocab_size], dtype=torch.int32, device=...)

LoRA 적용: output = hidden @ W^T + (hidden @ A^T) @ B^T

def apply_lora(self, base_output, hidden_states):
    lm_head_batch_info = self._get_lm_head_batch_info(hidden_states.shape[0])
    lora_a_output = self.lora_backend.run_lora_a_sgemm(
        hidden_states, self.lm_head_A_buffer,
        pruned_batch_info=lm_head_batch_info)
    lora_output = self.lora_backend.run_lora_b_sgemm(
        x=lora_a_output, weights=self.lm_head_B_buffer,
        output_offset=self.output_offset,
        base_output=base_output,
        pruned_batch_info=lm_head_batch_info)
    return lora_output

LM Head는 Chunked Logprobs에서 여러 pass로 나뉠 수 있으므로, _lm_head_pass_idx로 pass별 batch_info를 관리한다.

def set_lm_head_pass(self, pass_idx: int):
    self.lora_backend._lm_head_pass_idx = pass_idx

def reset_lm_head_pass(self):
    self.lora_backend._lm_head_pass_idx = None

LoRA 가중치 슬라이싱 규칙

TP 환경에서 LoRA A/B 가중치의 슬라이싱은 베이스 레이어의 병렬화 방식에 따라 결정된다.

┌──────────────────────┬────────────┬────────────┐
│ 레이어 타입            │ LoRA A     │ LoRA B     │
├──────────────────────┼────────────┼────────────┤
│ VocabEmbedding       │ 비샤딩      │ 비샤딩      │
│ ColumnParallelLinear │ 비샤딩      │ 열 방향 샤딩 │
│ RowParallelLinear    │ 열 방향 샤딩 │ 비샤딩      │
│ QKVParallelLinear    │ 비샤딩      │ qkv 별 샤딩 │
│ ParallelLMHead       │ 비샤딩      │ vocab 샤딩  │
└──────────────────────┴────────────┴────────────┘

LM Head B 슬라이싱의 실제 구현:

def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
    tp_size = self.base_layer.tp_size if hasattr(self.base_layer, "tp_size") else 1
    if tp_size <= 1:
        return B
    start_idx = self.base_layer.shard_indices.org_vocab_start_index
    end_idx = self.base_layer.shard_indices.org_vocab_end_index
    # B[:end_idx - start_idx, :] 범위로 슬라이싱

Extra Token 처리

Vocab 크기를 넘는 추가 토큰이 있는 경우에 대한 인터페이스가 정의되어 있으나, 현재 SGLang에서는 미지원이다.

def extra_token_embedding(self, input_, base_output):
    raise NotImplementedError(
        "Current SGLang codebase did not support tuned lora with extra/added tokens."
    )

Forward에서는 추가 토큰을 0으로 마스킹하여 base embedding의 out-of-bounds 접근을 방지한다.

def forward(self, input_: torch.Tensor):
    added_tokens_mask = input_ > self.vocab_size - 1
    base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0))
    if self.set_lora:
        base_output = self.apply_lora(base_output, input_, batch_info)
    return base_output

설계 근거

SGEMM 기반 멀티 LoRA 배칭

여러 요청이 서로 다른 LoRA를 사용하더라도, Segmented GEMM(SGEMM)을 사용하면 단일 커널 호출로 모든 요청의 LoRA 연산을 처리할 수 있다. 이는 S-LoRA 논문의 핵심 기여이다.

rank=0의 no-op 보장

LoRA가 없는 요청(base model only)은 weight_indices에서 rank=0 슬롯을 가리킨다. SGEMM 커널은 rank=0일 때 즉시 반환하여 연산 오버헤드가 없다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글