[SGLang] LoRA + MoE 융합: 어댑터와 전문가 혼합의 통합
들어가며
MoE(Mixture of Experts) 모델에 LoRA를 적용하려면, 전문가(expert) 라우팅과 LoRA delta를 올바른 순서로 결합해야 한다. 핵심은 LoRA delta가 활성화 함수(SiLU/GELU) 이전에 더해져야 한다는 것이다. SGLang의 TritonRunnerCoreWithLoRA는 MoE forward의 Gate/Up → Activation → Down 파이프라인 사이에 LoRA delta를 정확히 삽입한다.
구조도
입력 hidden_states
│
▼
┌──────────────────────────────────────────────────┐
│ Stage 1: Gate/Up 프로젝션 (base) │
│ invoke_fused_moe_kernel(hidden, w13, ...) │
│ → intermediate_cache1 │
└───────────────────────┬──────────────────────────┘
│
▼
┌──────────────────────────────────────────────────┐
│ Stage 1.5: LoRA Gate/Up Delta │
│ _add_lora_gate_up_delta() │
│ intermediate_cache1 += LoRA delta │
│ (activation 함수 이전에 적용!) │
└───────────────────────┬──────────────────────────┘
│
▼
┌──────────────────────────────────────────────────┐
│ Stage 2: Activation (SiLU/GELU) │
│ silu_and_mul(intermediate_cache1) → │
│ intermediate_cache2 │
└───────────────────────┬──────────────────────────┘
│
▼
┌──────────────────────────────────────────────────┐
│ Stage 3: Down 프로젝션 (base) │
│ invoke_fused_moe_kernel(cache2, w2, ...) │
│ → intermediate_cache3 │
└───────────────────────┬──────────────────────────┘
│
▼
┌──────────────────────────────────────────────────┐
│ Stage 3.5: LoRA Down Delta │
│ _add_lora_down_delta() │
│ intermediate_cache3 += LoRA delta │
│ (최종 reduce 이전에 적용!) │
└───────────────────────┬──────────────────────────┘
│
▼
최종 출력
핵심 코드 분석
LoRAInfo: MoE LoRA 메타데이터
python/sglang/srt/lora/lora_moe_runners.py에서 MoE LoRA에 필요한 모든 정보를 담는 데이터 클래스이다.
@dataclass
class LoRAInfo:
gate_up_lora_a_weights: torch.Tensor # [num_loras, num_experts_or_1, max_rank, hidden_dim]
gate_up_lora_b_weights: torch.Tensor # [num_loras, num_experts, gate_up_dim, max_rank]
down_lora_a_weights: torch.Tensor # [num_loras, num_experts, max_rank, intermediate_dim]
down_lora_b_weights: torch.Tensor # [num_loras, num_experts_or_1, hidden_dim, max_rank]
seg_indptr: torch.Tensor # (num_segments + 1,)
req_to_lora: torch.Tensor # (num_segments,)
lora_ranks: torch.Tensor # [num_loras]
adapter_enabled: torch.Tensor # [num_loras]
max_lora_rank: int
num_experts: int
experts_shared_outer_loras: bool = False
experts_shared_outer_loras=True일 때 gate_up_lora_a는 [num_loras, 1, ...] 형태로, 모든 전문가가 동일한 A 행렬을 공유한다. B 행렬은 전문가별 독립이다.
TritonRunnerCoreWithLoRA: MoE + LoRA 실행기
기본 TritonRunnerCore를 확장하여 LoRA delta 삽입 지점을 추가한다.
class TritonRunnerCoreWithLoRA(TritonRunnerCore):
def run(self, runner_input, quant_info, running_state, lora_info=None):
if lora_info is None:
return super().run(runner_input, quant_info, running_state)
if get_is_capture_mode():
has_active_lora = True # CUDA graph 캡처 시 항상 LoRA 경로 기록
else:
has_active_lora = lora_info.has_active_lora
if not has_active_lora:
return super().run(runner_input, quant_info, running_state)
CUDA Graph 캡처 시에도 LoRA 커널을 기록하되, adapter_enabled가 모두 0이므로 실질적 연산은 없다. 재생 시 실제 마스크로 교체된다.
MoE LoRA Alignment
표준 MoE의 토큰 정렬(sorted_token_ids, expert_ids)과 별도로, LoRA를 위한 정렬을 수행한다.
moe_lora_align_block_size(
topk_ids,
lora_info.seg_indptr,
lora_info.req_to_lora,
int(lora_info.num_experts),
int(block_size_m),
int(max_loras),
int(max_num_tokens_padded),
int(max_num_m_blocks),
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
lora_info.adapter_enabled,
lora_ids,
None, # expert_map
)
이 함수는 (adapter_id, expert_id) 쌍별로 토큰을 정렬한다. 결과는 [max_loras, max_num_tokens_padded] 형태의 2D 텐서로, 각 LoRA 어댑터별로 독립적인 토큰 목록을 제공한다.
Stage 1.5: Gate/Up LoRA Delta
기본 Gate/Up 프로젝션 결과에 LoRA delta를 더한다.
self._add_lora_gate_up_delta(
hidden_states=hidden_states,
intermediate_cache=intermediate_cache1,
topk_weights=topk_weights,
lora_info=lora_info,
sorted_token_ids_reshaped=sorted_token_ids_reshaped,
expert_ids_reshaped=expert_ids_reshaped,
num_tokens_post_padded_lora=num_tokens_post_padded_lora,
lora_ids=lora_ids,
)
이 단계가 activation 함수 이전에 수행되어야 하는 이유: LoRA delta가 활성화 이후에 적용되면 비선형 변환의 입력이 달라져 수학적으로 W + BA와 등가가 되지 않는다.
Fused MoE LoRA 커널
python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py의 Triton 커널은 MoE의 expert-parallel 구조와 LoRA의 adapter-parallel 구조를 융합한다.
@triton.jit
def _fused_moe_lora_kernel(
a_ptr, b_ptr, c_ptr,
topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr,
num_tokens_post_padded_ptr,
N, K, EM, num_valid_tokens, num_experts,
lora_ids, adapter_enabled,
...
):
커널 입력에서 lora_ids와 adapter_enabled를 받아, 각 program에서 해당 LoRA 어댑터의 전문가별 가중치를 선택하여 적용한다.
GPU 포인터 룩업 테이블
LoRA 가중치 포인터를 룩업 테이블로 캐싱하여 profile run 이후 변경되지 않는 포인터를 재사용한다.
_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}
def _get_ptr(lora_weights: list[torch.Tensor], device):
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
return ptr_tensor
tensor_ptrs = [lora_weight.data_ptr() for lora_weight in lora_weights]
ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
_LORA_PTR_DICT[key] = ptr_tensor
return ptr_tensor
CUDA Graph 버퍼 사전 할당
CUDA Graph 호환을 위해 중간 버퍼를 사전 할당한다.
if cg is not None:
sorted_token_ids_lora = cg["sorted_token_ids_lora"][:max_loras * max_num_tokens_padded]
expert_ids_lora = cg["expert_ids_lora"][:max_loras * max_num_m_blocks]
num_tokens_post_padded_lora = cg["num_tokens_post_padded_lora"][:max_loras]
else:
sorted_token_ids_lora = torch.empty(
(max_loras * max_num_tokens_padded,), dtype=torch.int32, device=device)
설계 근거
experts_shared_outer_loras
A 행렬(outer projection)을 전문가 간 공유하면 메모리를 절약할 수 있다. 전문가별 특화는 B 행렬(inner projection)에서만 이루어진다. 이는 [num_loras, 1, rank, hidden_dim] vs [num_loras, num_experts, rank, hidden_dim]의 차이이다.
LoRA delta 삽입 위치의 중요성
일반 Linear 레이어에서는 LoRA delta를 어디에 더하든 결과가 동일하지만, MoE에서는 gate_up → activation → down 파이프라인 내의 삽입 위치가 결과에 영향을 준다. 활성화 함수 전에 delta를 적용해야 (W + BA)(x)와 수학적으로 등가이다.
관련 포스트
- LoRA Triton 커널: SGMV, SGEMM 최적화 연산
- LoRA 백엔드: PyTorch, Triton, Chunked 구현 비교
- LoRA Manager: 어댑터 라이프사이클 관리
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] LoRA Triton 커널: SGMV, SGEMM 최적화 연산
- 현재글 : [SGLang] LoRA + MoE 융합: 어댑터와 전문가 혼합의 통합
- 다음글 [SGLang] LoRA Eviction: 어댑터 캐시 관리와 퇴거 정책
댓글