본문으로 건너뛰기

[SGLang] LoRA + MoE 융합: 어댑터와 전문가 혼합의 통합

들어가며

MoE(Mixture of Experts) 모델에 LoRA를 적용하려면, 전문가(expert) 라우팅과 LoRA delta를 올바른 순서로 결합해야 한다. 핵심은 LoRA delta가 활성화 함수(SiLU/GELU) 이전에 더해져야 한다는 것이다. SGLang의 TritonRunnerCoreWithLoRA는 MoE forward의 Gate/Up → Activation → Down 파이프라인 사이에 LoRA delta를 정확히 삽입한다.

구조도

입력 hidden_states
       │
       ▼
┌──────────────────────────────────────────────────┐
│ Stage 1: Gate/Up 프로젝션 (base)                   │
│   invoke_fused_moe_kernel(hidden, w13, ...)       │
│   → intermediate_cache1                           │
└───────────────────────┬──────────────────────────┘
                        │
                        ▼
┌──────────────────────────────────────────────────┐
│ Stage 1.5: LoRA Gate/Up Delta                     │
│   _add_lora_gate_up_delta()                       │
│   intermediate_cache1 += LoRA delta               │
│   (activation 함수 이전에 적용!)                    │
└───────────────────────┬──────────────────────────┘
                        │
                        ▼
┌──────────────────────────────────────────────────┐
│ Stage 2: Activation (SiLU/GELU)                   │
│   silu_and_mul(intermediate_cache1) →              │
│   intermediate_cache2                              │
└───────────────────────┬──────────────────────────┘
                        │
                        ▼
┌──────────────────────────────────────────────────┐
│ Stage 3: Down 프로젝션 (base)                      │
│   invoke_fused_moe_kernel(cache2, w2, ...)        │
│   → intermediate_cache3                            │
└───────────────────────┬──────────────────────────┘
                        │
                        ▼
┌──────────────────────────────────────────────────┐
│ Stage 3.5: LoRA Down Delta                        │
│   _add_lora_down_delta()                          │
│   intermediate_cache3 += LoRA delta               │
│   (최종 reduce 이전에 적용!)                        │
└───────────────────────┬──────────────────────────┘
                        │
                        ▼
                   최종 출력

핵심 코드 분석

LoRAInfo: MoE LoRA 메타데이터

python/sglang/srt/lora/lora_moe_runners.py에서 MoE LoRA에 필요한 모든 정보를 담는 데이터 클래스이다.

@dataclass
class LoRAInfo:
    gate_up_lora_a_weights: torch.Tensor   # [num_loras, num_experts_or_1, max_rank, hidden_dim]
    gate_up_lora_b_weights: torch.Tensor   # [num_loras, num_experts, gate_up_dim, max_rank]
    down_lora_a_weights: torch.Tensor      # [num_loras, num_experts, max_rank, intermediate_dim]
    down_lora_b_weights: torch.Tensor      # [num_loras, num_experts_or_1, hidden_dim, max_rank]

    seg_indptr: torch.Tensor       # (num_segments + 1,)
    req_to_lora: torch.Tensor      # (num_segments,)
    lora_ranks: torch.Tensor       # [num_loras]
    adapter_enabled: torch.Tensor  # [num_loras]
    max_lora_rank: int

    num_experts: int
    experts_shared_outer_loras: bool = False

experts_shared_outer_loras=True일 때 gate_up_lora_a는 [num_loras, 1, ...] 형태로, 모든 전문가가 동일한 A 행렬을 공유한다. B 행렬은 전문가별 독립이다.

TritonRunnerCoreWithLoRA: MoE + LoRA 실행기

기본 TritonRunnerCore를 확장하여 LoRA delta 삽입 지점을 추가한다.

class TritonRunnerCoreWithLoRA(TritonRunnerCore):
    def run(self, runner_input, quant_info, running_state, lora_info=None):
        if lora_info is None:
            return super().run(runner_input, quant_info, running_state)

        if get_is_capture_mode():
            has_active_lora = True  # CUDA graph 캡처 시 항상 LoRA 경로 기록
        else:
            has_active_lora = lora_info.has_active_lora
        if not has_active_lora:
            return super().run(runner_input, quant_info, running_state)

CUDA Graph 캡처 시에도 LoRA 커널을 기록하되, adapter_enabled가 모두 0이므로 실질적 연산은 없다. 재생 시 실제 마스크로 교체된다.

MoE LoRA Alignment

표준 MoE의 토큰 정렬(sorted_token_ids, expert_ids)과 별도로, LoRA를 위한 정렬을 수행한다.

moe_lora_align_block_size(
    topk_ids,
    lora_info.seg_indptr,
    lora_info.req_to_lora,
    int(lora_info.num_experts),
    int(block_size_m),
    int(max_loras),
    int(max_num_tokens_padded),
    int(max_num_m_blocks),
    sorted_token_ids_lora,
    expert_ids_lora,
    num_tokens_post_padded_lora,
    lora_info.adapter_enabled,
    lora_ids,
    None,  # expert_map
)

이 함수는 (adapter_id, expert_id) 쌍별로 토큰을 정렬한다. 결과는 [max_loras, max_num_tokens_padded] 형태의 2D 텐서로, 각 LoRA 어댑터별로 독립적인 토큰 목록을 제공한다.

Stage 1.5: Gate/Up LoRA Delta

기본 Gate/Up 프로젝션 결과에 LoRA delta를 더한다.

self._add_lora_gate_up_delta(
    hidden_states=hidden_states,
    intermediate_cache=intermediate_cache1,
    topk_weights=topk_weights,
    lora_info=lora_info,
    sorted_token_ids_reshaped=sorted_token_ids_reshaped,
    expert_ids_reshaped=expert_ids_reshaped,
    num_tokens_post_padded_lora=num_tokens_post_padded_lora,
    lora_ids=lora_ids,
)

이 단계가 activation 함수 이전에 수행되어야 하는 이유: LoRA delta가 활성화 이후에 적용되면 비선형 변환의 입력이 달라져 수학적으로 W + BA와 등가가 되지 않는다.

Fused MoE LoRA 커널

python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py의 Triton 커널은 MoE의 expert-parallel 구조와 LoRA의 adapter-parallel 구조를 융합한다.

@triton.jit
def _fused_moe_lora_kernel(
    a_ptr, b_ptr, c_ptr,
    topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr,
    num_tokens_post_padded_ptr,
    N, K, EM, num_valid_tokens, num_experts,
    lora_ids, adapter_enabled,
    ...
):

커널 입력에서 lora_idsadapter_enabled를 받아, 각 program에서 해당 LoRA 어댑터의 전문가별 가중치를 선택하여 적용한다.

GPU 포인터 룩업 테이블

LoRA 가중치 포인터를 룩업 테이블로 캐싱하여 profile run 이후 변경되지 않는 포인터를 재사용한다.

_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}

def _get_ptr(lora_weights: list[torch.Tensor], device):
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
    if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
        return ptr_tensor
    tensor_ptrs = [lora_weight.data_ptr() for lora_weight in lora_weights]
    ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
    _LORA_PTR_DICT[key] = ptr_tensor
    return ptr_tensor

CUDA Graph 버퍼 사전 할당

CUDA Graph 호환을 위해 중간 버퍼를 사전 할당한다.

if cg is not None:
    sorted_token_ids_lora = cg["sorted_token_ids_lora"][:max_loras * max_num_tokens_padded]
    expert_ids_lora = cg["expert_ids_lora"][:max_loras * max_num_m_blocks]
    num_tokens_post_padded_lora = cg["num_tokens_post_padded_lora"][:max_loras]
else:
    sorted_token_ids_lora = torch.empty(
        (max_loras * max_num_tokens_padded,), dtype=torch.int32, device=device)

설계 근거

experts_shared_outer_loras

A 행렬(outer projection)을 전문가 간 공유하면 메모리를 절약할 수 있다. 전문가별 특화는 B 행렬(inner projection)에서만 이루어진다. 이는 [num_loras, 1, rank, hidden_dim] vs [num_loras, num_experts, rank, hidden_dim]의 차이이다.

LoRA delta 삽입 위치의 중요성

일반 Linear 레이어에서는 LoRA delta를 어디에 더하든 결과가 동일하지만, MoE에서는 gate_up → activation → down 파이프라인 내의 삽입 위치가 결과에 영향을 준다. 활성화 함수 전에 delta를 적용해야 (W + BA)(x)와 수학적으로 등가이다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글