본문으로 건너뛰기

[SGLang] Multi-Layer EAGLE: 다계층 드래프트로 더 긴 추측

들어가며

기본 EAGLE은 단일 드래프트 모델 레이어로 다음 토큰을 추측한다. Multi-Layer EAGLE은 여러 개의 독립 드래프트 레이어를 순차적으로 실행하여 각 step에서 서로 다른 가중치로 예측한다. 이 방식은 단일 레이어 대비 더 높은 예측 정확도를 제공하며, 특히 긴 추측 시퀀스에서 acceptance rate 감소를 완화한다. SGLang의 MultiLayerEagleWorkermodel_runner_list를 통해 step별 독립 모델을 관리한다.

구조도

┌──────────────────────────────────────────────────┐
│              MultiLayerEagleWorker                 │
│                                                    │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐        │
│  │ MTP Layer │  │ MTP Layer │  │ MTP Layer │        │
│  │  step=0   │  │  step=1   │  │  step=2   │        │
│  │(model_    │  │(model_    │  │(model_    │        │
│  │ runner[0])│  │ runner[1])│  │ runner[2])│        │
│  └─────┬─────┘  └─────┬─────┘  └─────┬─────┘      │
│        │              │              │              │
│        ▼              ▼              ▼              │
│  ┌─────────┐   ┌─────────┐   ┌─────────┐          │
│  │ topk_p₀ │   │ topk_p₁ │   │ topk_p₂ │          │
│  │ topk_i₀ │   │ topk_i₁ │   │ topk_i₂ │          │
│  └─────────┘   └─────────┘   └─────────┘          │
│        │              │              │              │
│        └──────────────┼──────────────┘              │
│                       ▼                             │
│              ┌────────────────┐                     │
│              │ concat topk_p  │                     │
│              │ concat topk_i  │                     │
│              └────────────────┘                     │
│                                                    │
│  ┌──────────────────────────────┐                  │
│  │      Target Model (Verify)   │                  │
│  └──────────────────────────────┘                  │
└──────────────────────────────────────────────────┘

핵심 코드 분석

1. 다계층 모델 초기화

python/sglang/srt/speculative/multi_layer_eagle_worker.py에서 각 step마다 독립 ModelRunner를 사용한다.

class MultiLayerEagleWorker(TpModelWorker):
    def __init__(self, server_args, ...):
        self.speculative_num_steps = server_args.speculative_num_steps
        # super().__init__()에서 model_runner_list가 생성됨
        # is_multi_layer_eagle=True 플래그로 다계층 모드 활성화

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
        for i in range(self.speculative_num_steps):
            self.mtp_model_runner(i).model.set_embed_and_head(embed, head)

    def mtp_model_runner(self, layer_id: int) -> ModelRunner:
        return self.model_runner_list[layer_id]

model_runner_list[i]가 i번째 step의 드래프트 모델이다. 모든 레이어가 타겟의 embed/head를 공유하지만 중간 가중치는 독립적이다.

2. 드래프트 단계: step별 독립 forward

def draft(self, batch: ScheduleBatch):
    scores = None
    input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
        0, topk_p, topk_index, hidden_states, scores, self.topk
    )

    if self.speculative_num_steps == 1:
        score_list.append(tree_info[0])
        token_list.append(tree_info[1])
        parents_list.append(tree_info[2])
    else:
        for i in range(self.speculative_num_steps):
            score_list.append(tree_info[0][:, :, i].unsqueeze(-1))
            token_index = tree_info[1][:, i].unsqueeze(-1)
            token_list.append(token_index)
            if i == 0:
                parents_list.append(tree_info[2])
            else:
                parents_list.append(
                    torch.full((tree_info[2].size(0), 1), i, ...)
                )

Multi-Layer에서는 select_top_k_tokens의 결과를 step별로 분리하여 트리의 각 레벨을 독립적으로 관리한다. speculative_num_steps > 1일 때 score와 token을 step 차원에서 slice한다.

3. 드래프트 확장: step별 순차 실행

prefill 후 드래프트 KV 캐시를 채우는 forward_draft_extend에서 각 레이어를 순차 실행한다.

def forward_draft_extend(self, batch, hidden_states, next_token_ids, seq_lens_cpu):
    topk_p_list = []
    topk_index_list = []
    for step in range(self.speculative_num_steps):
        logits_output = self.mtp_model_runner(step).forward(forward_batch).logits_output
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
        topk_p_list.append(topk_p)
        topk_index_list.append(topk_index)
        # 다음 step의 입력 갱신
        for i, extend_len in enumerate(forward_batch.extend_seq_lens):
            input_ids = forward_batch.input_ids[pt : pt + extend_len]
            forward_batch.input_ids[pt : pt + extend_len] = torch.cat(
                (input_ids[1:], topk_index[i].reshape(1))
            )

    forward_batch.spec_info.topk_p = torch.cat(topk_p_list, dim=1)
    forward_batch.spec_info.topk_index = torch.cat(topk_index_list, dim=1)

각 step에서 이전 step의 예측 토큰을 다음 입력에 연결하고, 최종적으로 모든 step의 topk_ptopk_index를 concat한다.

4. 디코드 후 드래프트 확장: CUDA Graph 활용

forward_draft_extend_after_decode는 디코드 후 수락된 토큰으로 드래프트 캐시를 갱신한다.

def forward_draft_extend_after_decode(self, batch):
    for step in range(self.speculative_num_steps):
        can_cuda_graph = len(self.cuda_graph_runner_for_draft_extend_list) and \
            self.cuda_graph_runner_for_draft_extend_list[step].can_run(forward_batch)
        if can_cuda_graph:
            logits_output = self.cuda_graph_runner_for_draft_extend_list[step].replay(forward_batch)
        else:
            self.mtp_model_runner(step).attn_backend.init_forward_metadata(forward_batch)
            logits_output = self.mtp_model_runner(step).forward(forward_batch, ...).logits_output

각 step마다 별도의 MultiLayerEagleDraftExtendCudaGraphRunner가 있어, step별로 CUDA graph 사용 여부를 독립적으로 결정한다.

5. Attention Backend 초기화

def init_attention_backend(self):
    for step in range(self.speculative_num_steps):
        draft_backend_factory = DraftBackendFactory(
            self.server_args, self.mtp_model_runner(step),
            self.topk, self.speculative_num_steps,
        )
        self.draft_extend_attn_backend_list.append(
            draft_backend_factory.create_draft_extend_backend()
        )

각 step의 attention backend가 독립적으로 초기화된다. DraftBackendFactoryspeculative_attention_mode 설정에 따라 적절한 backend를 생성한다.

6. 검증 단계

Multi-Layer EAGLE의 검증은 기본 EAGLE과 동일한 EagleVerifyInput.verify()를 사용한다.

def verify(self, batch, spec_info: EagleVerifyInput):
    spec_info.prepare_for_verify(batch, self.page_size)
    batch.forward_mode = ForwardMode.TARGET_VERIFY
    batch.spec_info = spec_info
    model_worker_batch.return_hidden_states_before_norm = True
    batch_result = self.target_worker.forward_batch_generation(model_worker_batch, is_verify=True)
    res = spec_info.verify(batch, logits_output, self.token_to_kv_pool_allocator, ...)

return_hidden_states_before_norm = True로 정규화 전 hidden states를 받아 다음 드래프트 입력으로 활용한다.

단일 계층 vs 다계층 비교

항목 Single-Layer EAGLE Multi-Layer EAGLE
모델 수 1개 드래프트 모델 N개 (step당 1개)
메모리 낮음 높음 (N배 가중치)
예측 정확도 step이 증가할수록 감소 step별 전문 모델로 완화
CUDA Graph 1개 runner step별 독립 runner
hidden_states 필요 예 (타겟에서 추출) 아니오 (enable_multi_layer_eagle 시 skip)

spec_need_hidden_states() 함수가 이를 반영한다:

def spec_need_hidden_states(server_args=None):
    return not server_args.enable_multi_layer_eagle

설계 근거

step별 전문화: 1번째 토큰 예측과 5번째 토큰 예측은 난이도가 다르다. 각 step에 특화된 모델을 두면 깊은 추측에서도 높은 acceptance rate를 유지할 수 있다.

CUDA Graph step 분리: step별로 독립 CUDA graph를 캡처하여 배치 크기 변화에 유연하게 대응한다.

캐시 풀 공유: 드래프트와 타겟이 동일한 token_to_kv_pool_allocator를 공유하여 clear_cache_pool이 no-op이다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글