본문으로 건너뛰기

[SGLang] DFlash: Flash 기반 고속 드래프팅

들어가며

DFlash는 SGLang의 speculative decoding 알고리즘 중 하나로, 드래프트 모델이 블록 단위로 토큰을 한 번에 생성하는 방식이다. EAGLE이 autoregressive하게 한 토큰씩 드래프트를 생성하는 반면, DFlash는 mask token을 사용하여 블록 전체를 병렬로 생성한다. 이를 통해 드래프팅 단계 자체의 latency를 크게 줄인다. 검증은 선형(non-tree) 구조로 수행하며, 타겟 모델의 hidden states를 활용하여 드래프트 KV 캐시를 효율적으로 관리한다.

구조도

┌──────────────────────────────────────────────────────┐
│                    DFlashWorker                       │
│                                                      │
│  ┌──────────────┐      ┌──────────────────────────┐ │
│  │ Target Worker │      │ Draft Worker (TpModelWorker│ │
│  │ (target_worker│      │  별도 attention backend)  │ │
│  └──────┬───────┘      └───────────┬──────────────┘ │
│         │                          │                 │
│  Target Extend:                 Draft:               │
│  hidden_states ─────────▶ target_hidden              │
│         │                   │                        │
│         │              ┌────▼─────────┐              │
│         │              │ Block Draft   │              │
│         │              │ [M][M]...[M]  │              │
│         │              │ mask_token_id │              │
│         │              │ → block_size  │              │
│         │              │   tokens      │              │
│         │              └────┬──────────┘              │
│         │                   │                        │
│         │              ┌────▼─────────┐              │
│         │              │DFlashVerify   │              │
│         │              │Input          │              │
│         ▼              │ .draft_token  │              │
│  ┌──────────────┐      │ .positions    │              │
│  │Target Verify  │◀────│ .custom_mask  │              │
│  │(linear verify)│      └──────────────┘              │
│  └──────┬───────┘                                    │
│         │                                            │
│         ▼                                            │
│  accept_len + bonus → DFlashDraftInput (다음 step)    │
└──────────────────────────────────────────────────────┘

핵심 코드 분석

1. DFlashWorker 초기화

python/sglang/srt/speculative/dflash_worker.py에서 DFlashWorker는 별도의 드래프트 워커를 생성한다.

class DFlashWorker:
    def __init__(self, server_args, ..., target_worker):
        self.target_worker = target_worker
        self.draft_window_size = server_args.speculative_dflash_draft_window_size
        self.use_compact_draft_cache = self.draft_window_size is not None

        draft_server_args = deepcopy(server_args)
        draft_server_args.attention_backend = draft_backend  # flashinfer/fa3/fa4
        self.draft_worker = TpModelWorker(
            server_args=draft_server_args, is_draft_worker=True,
            req_to_token_pool=shared_req_to_token_pool,
            token_to_kv_pool_allocator=target_token_to_kv_pool_allocator,
        )
        self.block_size = server_args.speculative_num_draft_tokens

draft_window_size가 설정되면 compact draft cache 모드가 활성화되어, 최근 윈도우만 드래프트 KV 캐시에 유지한다. 드래프트 attention backend는 flashinfer, fa3, fa4만 지원한다.

2. Mask Token 기반 블록 드래프팅

DFlash의 핵심은 mask token을 사용한 병렬 드래프트 생성이다.

self._mask_token = draft_config.mask_token
self._mask_token_id_override = draft_config.mask_token_id
self._mask_token_id = self._resolve_mask_token_id(
    mask_token=self._mask_token,
    mask_token_id=self._mask_token_id_override,
)

드래프트 모델은 [mask] 토큰 블록을 입력으로 받아 한 번의 forward로 block_size개의 토큰을 동시에 예측한다. EAGLE처럼 step별로 반복하지 않아 드래프팅 latency가 줄어든다.

3. DFlashDraftInput: 상태 관리

python/sglang/srt/speculative/dflash_info.py에서 드래프트 상태를 관리한다.

@dataclass
class DFlashDraftInput(SpecInput):
    verified_id: torch.Tensor        # 현재 토큰 (요청당 1개)
    target_hidden: torch.Tensor      # [sum(ctx_lens), K * hidden_size]
    ctx_lens: torch.Tensor           # 요청당 context 길이
    draft_seq_lens: torch.Tensor     # 드래프트 KV 캐시에 반영된 길이

    def __post_init__(self):
        super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT)

target_hidden은 타겟 모델에서 추출한 hidden states로, 드래프트 모델의 KV 캐시를 직접 구성하는 데 사용된다. ctx_lens가 요청별로 몇 개의 새 토큰이 아직 드래프트 캐시에 반영되지 않았는지를 추적한다.

4. DFlashVerifyInput: 선형 검증

@dataclass
class DFlashVerifyInput(SpecInput):
    draft_token: torch.Tensor
    positions: torch.Tensor
    draft_token_num: int
    topk: int = 1  # DFlash는 항상 선형(non-tree)
    custom_mask: torch.Tensor | None = None

topk = 1로 고정된 것이 DFlash의 특징이다. 트리 구조가 아닌 선형 체인으로 검증하여 구현이 단순하고, custom mask는 표준 causal masking을 따른다.

5. 검증 로직: accept_len + bonus

def verify(self, *, batch, logits_output, page_size):
    candidates = self.draft_token.view(bs, self.draft_token_num)
    if not sampling_info.is_all_greedy and is_dflash_sampling_verify_available():
        accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus(
            candidates=candidates,
            next_token_logits=logits_output.next_token_logits,
            sampling_info=sampling_info,
        )
    else:
        target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
        accept_len, bonus = compute_dflash_accept_len_and_bonus(
            candidates=candidates, target_predict=target_predict.view(bs, self.draft_token_num),
        )

DFlash는 선형 체인이므로 검증이 단순하다. 첫 번째 불일치 위치까지를 accept_len으로, 불일치 위치의 타겟 예측을 bonus 토큰으로 반환한다.

6. Fused KV Materialization

DFlash는 타겟의 hidden states에서 드래프트 KV 캐시를 직접 구성하는 fused 경로를 제공한다.

self._use_fused_kv_materialize = is_cuda()
if self._use_fused_kv_materialize:
    self._init_fused_kv_helper()

def _init_fused_kv_helper(self):
    FusedKVMaterializeHelper = _get_fused_kv_materialize_helper()
    self._fused_kv_helper = FusedKVMaterializeHelper(
        layers=layers, rotary_emb=rotary_emb,
        num_kv_heads=first_attn.num_kv_heads,
        head_dim=first_attn.head_dim, device=self.device,
    )

Fused KV materialization은 QKV projection + RoPE + KV cache write를 하나의 Triton 커널로 통합하여, 레이어별 sequential 처리를 병렬화한다.

7. 페이지 정렬 캐시 해제

def _compute_paged_keep_slots(*, prefix_lens, commit_lens, draft_token_num, page_size):
    extended_lens = prefix_lens + int(draft_token_num)
    new_lens = prefix_lens + commit_lens.to(seq_dtype)
    aligned_new_lens = ((new_lens + page_size - 1) // page_size) * page_size
    keep_lens = torch.minimum(aligned_new_lens, extended_lens)
    keep_slots = (keep_lens - prefix_lens).to(torch.int64)
    keep_slots.clamp_(min=0, max=int(draft_token_num))
    return keep_slots

paged KV cache에서는 페이지 단위로만 해제할 수 있으므로, 수락된 토큰 이후의 전체 페이지만 반환한다.

EAGLE vs DFlash 비교

항목 EAGLE DFlash
드래프팅 방식 Autoregressive (step별) 블록 병렬 (mask token)
드래프팅 latency N * forward 1 * forward
검증 구조 트리 (topk > 1) 선형 (topk = 1)
KV 캐시 구성 드래프트 모델 자체 생성 타겟 hidden states에서 직접 구성
Overlap V2 지원 아니오
적합한 모델 범용 블록 드래프팅 지원 모델

설계 근거

블록 드래프팅: autoregressive 드래프트는 step마다 GPU idle time이 발생한다. 블록 단위 생성은 한 번의 forward로 전체 드래프트를 완료하여 latency를 최소화한다.

선형 검증의 단순함: 트리 검증은 복잡한 인덱스 관리가 필요하지만, 선형 검증은 첫 불일치까지만 찾으면 되므로 구현이 간단하고 커널 오버헤드가 적다.

Draft Windowing: draft_window_size로 드래프트 KV 캐시를 최근 윈도우로 제한하여, 긴 시퀀스에서도 메모리 사용량을 일정하게 유지한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글