[SGLang] LoRA Layers: QKV, Gate/Up 프로젝션 어댑터
들어가며
LoRA(Low-Rank Adaptation)는 사전학습된 모델의 가중치를 동결하고, 저랭크 행렬 A와 B를 추가하여 W + B @ A 형태로 파인튜닝한다. SGLang은 Transformer의 주요 레이어(Embedding, QKV, Gate/Up, Down, LM Head)에 LoRA 래퍼를 제공하며, 각 레이어의 병렬화 특성에 맞게 가중치를 슬라이싱한다.
구조도
Input x
│
├──► base_layer.forward(x) ──► base_output
│
└──► LoRA path:
x ──► lora_A (shrink) ──► lora_B (expand) ──► lora_delta
│
▼
base_output + lora_delta
│
▼
output
레이어별 LoRA 적용:
┌────────────────┬─────────────────┬──────────────────┐
│ VocabEmbedding │ QKV Projection │ Gate/Up + Down │
│ A: embed lookup│ A: (3r, in_dim) │ A: (2r, in_dim) │
│ B: (emb, r) │ B: (q+2kv, r) │ B: (2*out, r) │
└────────────────┴─────────────────┴──────────────────┘
핵심 코드 분석
BaseLayerWithLoRA: 기본 래퍼
python/sglang/srt/lora/layers.py의 모든 LoRA 레이어는 BaseLayerWithLoRA를 상속한다.
class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer: nn.Module, lora_backend: BaseLoRABackend):
super().__init__()
self.base_layer = base_layer
self.set_lora: bool = False
self.lora_backend = lora_backend
if hasattr(self.base_layer, "weight"):
self.weight = self.base_layer.weight
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
set_lora가 False이면 기본 레이어 그대로 동작한다. LoRA 가중치가 설정되면 forward에서 delta를 추가한다.
VocabParallelEmbeddingWithLoRA: 임베딩 LoRA
임베딩 레이어의 LoRA는 특별하다. A 행렬은 임베딩 룩업으로 동작하고, B 행렬은 일반 행렬곱이다.
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def apply_lora(self, base_output, input_, batch_info):
# A: embedding lookup (rank, vocab_size) -> (s, rank)
lora_a_output = self.run_lora_a_embedding(input_, batch_info)
# B: matmul (embed_dim, rank) -> (s, embed_dim)
lora_output = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output,
weights=self.embedding_B_buffer,
output_offset=self.output_offset,
base_output=base_output,
)
return lora_output
TP > 1 환경에서 임베딩 LoRA는 가중치를 샤딩하지 않고 전체 복제한다. 베이스 임베딩의 all-reduce 후에 LoRA delta를 더하므로 수학적으로 올바르다.
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A # 비샤딩
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
return B # 비샤딩
ParallelLMHeadWithLoRA: LM Head LoRA
LM Head는 column-parallel이므로 B 행렬을 vocab 차원으로 슬라이싱한다.
class ParallelLMHeadWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ParallelLMHead, lora_backend):
super().__init__(base_layer, lora_backend)
tp_size = base_layer.tp_size if hasattr(base_layer, "tp_size") else 1
if tp_size > 1:
self.shard_vocab_size = get_lm_head_lora_b_shard_size(
self.vocab_size, shard_indices=base_layer.shard_indices)
self.output_offset = torch.tensor(
[0, self.shard_vocab_size], dtype=torch.int32, device=...)
LoRA 적용: output = hidden @ W^T + (hidden @ A^T) @ B^T
def apply_lora(self, base_output, hidden_states):
lm_head_batch_info = self._get_lm_head_batch_info(hidden_states.shape[0])
lora_a_output = self.lora_backend.run_lora_a_sgemm(
hidden_states, self.lm_head_A_buffer,
pruned_batch_info=lm_head_batch_info)
lora_output = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output, weights=self.lm_head_B_buffer,
output_offset=self.output_offset,
base_output=base_output,
pruned_batch_info=lm_head_batch_info)
return lora_output
LM Head는 Chunked Logprobs에서 여러 pass로 나뉠 수 있으므로, _lm_head_pass_idx로 pass별 batch_info를 관리한다.
def set_lm_head_pass(self, pass_idx: int):
self.lora_backend._lm_head_pass_idx = pass_idx
def reset_lm_head_pass(self):
self.lora_backend._lm_head_pass_idx = None
LoRA 가중치 슬라이싱 규칙
TP 환경에서 LoRA A/B 가중치의 슬라이싱은 베이스 레이어의 병렬화 방식에 따라 결정된다.
┌──────────────────────┬────────────┬────────────┐
│ 레이어 타입 │ LoRA A │ LoRA B │
├──────────────────────┼────────────┼────────────┤
│ VocabEmbedding │ 비샤딩 │ 비샤딩 │
│ ColumnParallelLinear │ 비샤딩 │ 열 방향 샤딩 │
│ RowParallelLinear │ 열 방향 샤딩 │ 비샤딩 │
│ QKVParallelLinear │ 비샤딩 │ qkv 별 샤딩 │
│ ParallelLMHead │ 비샤딩 │ vocab 샤딩 │
└──────────────────────┴────────────┴────────────┘
LM Head B 슬라이싱의 실제 구현:
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
tp_size = self.base_layer.tp_size if hasattr(self.base_layer, "tp_size") else 1
if tp_size <= 1:
return B
start_idx = self.base_layer.shard_indices.org_vocab_start_index
end_idx = self.base_layer.shard_indices.org_vocab_end_index
# B[:end_idx - start_idx, :] 범위로 슬라이싱
Extra Token 처리
Vocab 크기를 넘는 추가 토큰이 있는 경우에 대한 인터페이스가 정의되어 있으나, 현재 SGLang에서는 미지원이다.
def extra_token_embedding(self, input_, base_output):
raise NotImplementedError(
"Current SGLang codebase did not support tuned lora with extra/added tokens."
)
Forward에서는 추가 토큰을 0으로 마스킹하여 base embedding의 out-of-bounds 접근을 방지한다.
def forward(self, input_: torch.Tensor):
added_tokens_mask = input_ > self.vocab_size - 1
base_output = self.base_layer.forward(input_.masked_fill(added_tokens_mask, 0))
if self.set_lora:
base_output = self.apply_lora(base_output, input_, batch_info)
return base_output
설계 근거
SGEMM 기반 멀티 LoRA 배칭
여러 요청이 서로 다른 LoRA를 사용하더라도, Segmented GEMM(SGEMM)을 사용하면 단일 커널 호출로 모든 요청의 LoRA 연산을 처리할 수 있다. 이는 S-LoRA 논문의 핵심 기여이다.
rank=0의 no-op 보장
LoRA가 없는 요청(base model only)은 weight_indices에서 rank=0 슬롯을 가리킨다. SGEMM 커널은 rank=0일 때 즉시 반환하여 연산 오버헤드가 없다.
관련 포스트
- LoRA Manager: 어댑터 라이프사이클 관리
- LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교
- LoRA Triton 커널: SGMV, SGEMM 최적화 연산
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] LoRA Manager: 어댑터 라이프사이클 관리
- 현재글 : [SGLang] LoRA Layers: QKV, Gate/Up 프로젝션 어댑터
- 다음글 [SGLang] LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교
댓글